oktopus/backend/services/mtp/stomp/conn.go

775 lines
20 KiB
Go

package stomp
import (
"errors"
"io"
"net"
"strconv"
"sync"
"time"
"github.com/go-stomp/stomp/v3/frame"
)
const (
	// DefaultHeartBeatError is the slack added to the read heart-beat timeout
	// (and removed from the write heart-beat timeout) so that network latency
	// does not cause premature disconnections.
	DefaultHeartBeatError = 5 * time.Second

	// DefaultMsgSendTimeout bounds how long Conn.Send waits to queue a frame.
	DefaultMsgSendTimeout = 10 * time.Second

	// DefaultRcvReceiptTimeout bounds how long Conn.Send waits for a RECEIPT.
	DefaultRcvReceiptTimeout = 30 * time.Second

	// DefaultDisconnectReceiptTimeout bounds how long Conn.Disconnect waits for
	// the RECEIPT that acknowledges the DISCONNECT frame.
	DefaultDisconnectReceiptTimeout = 30 * time.Second

	// ReplyToHeader is the header used for temporary queues / RPC with RabbitMQ.
	ReplyToHeader = "reply-to"
)
// Conn is a connection to a STOMP server. Obtain one from Dial (which also
// opens the network connection) or Connect (which uses an existing one).
type Conn struct {
	// conn is the underlying transport; it is closed by Disconnect and
	// MustDisconnect.
	conn io.ReadWriteCloser
	// readCh carries frames from readLoop to processLoop; a nil frame is a
	// received heart-beat. readLoop closes it when a read fails.
	readCh chan *frame.Frame
	// writeCh carries outgoing requests to processLoop; MustDisconnect closes it.
	writeCh chan writeRequest

	// version, session and server are taken from the CONNECTED response.
	version Version
	session string
	server  string

	// readTimeout and writeTimeout are the negotiated heart-beat intervals,
	// already adjusted by the heart-beat error margin.
	readTimeout  time.Duration
	writeTimeout time.Duration

	// msgSendTimeout bounds queueing a frame in Send; rcvReceiptTimeout and
	// disconnectReceiptTimeout bound the RECEIPT waits in Send and Disconnect.
	msgSendTimeout           time.Duration
	rcvReceiptTimeout        time.Duration
	disconnectReceiptTimeout time.Duration

	// hbGracePeriodMultiplier scales readTimeout when processLoop arms its
	// read heart-beat timer.
	hbGracePeriodMultiplier float64

	// closed is guarded by closeMutex.
	closed     bool
	closeMutex *sync.Mutex

	// NOTE(review): options is not assigned anywhere in the code visible here.
	options *connOptions
	log     Logger
}
// writeRequest is a unit of work handed to processLoop through Conn.writeCh.
type writeRequest struct {
	// Frame is the frame to transmit.
	Frame *frame.Frame
	// C, when non-nil, is where processLoop delivers the frames that answer
	// this request (RECEIPT, ERROR, or MESSAGE frames for a subscription).
	C chan *frame.Frame
}
// Dial creates a network connection to a STOMP server and performs
// the STOMP connect protocol sequence. The network endpoint of the
// STOMP server is specified by network and addr. STOMP protocol
// options can be specified in opts.
//
// The STOMP host header defaults to the host part of the remote address; a
// host given explicitly in opts overrides it. Because Dial owns the network
// connection it opened, that connection is closed whenever Dial returns an
// error.
func Dial(network, addr string, opts ...func(*Conn) error) (*Conn, error) {
	netConn, err := net.Dial(network, addr)
	if err != nil {
		return nil, err
	}

	host, _, err := net.SplitHostPort(netConn.RemoteAddr().String())
	if err != nil {
		netConn.Close()
		return nil, err
	}

	// Prepend the host option so that an explicitly supplied host, applied
	// later, overrides it.
	opts = append([]func(*Conn) error{ConnOpt.Host(host)}, opts...)

	stompConn, err := Connect(netConn, opts...)
	if err != nil {
		// Connect leaves the transport open on failure; release the socket
		// Dial created so it does not leak.
		netConn.Close()
		return nil, err
	}
	return stompConn, nil
}
// Connect creates a STOMP connection and performs the STOMP connect
// protocol sequence. The connection to the STOMP server has already
// been created by the program. The opts parameter provides the
// opportunity to specify STOMP protocol options.
//
// Connect writes the CONNECT frame and reads one reply. It fails if the reply
// is empty, is not CONNECTED, names an unsupported version, or carries a
// malformed heart-beat header. On success it records the server, session,
// version and heart-beat timeouts, then starts the readLoop and processLoop
// goroutines. On failure conn is not closed; the caller still owns it.
func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error) {
	c := &Conn{
		conn:       conn,
		closeMutex: &sync.Mutex{},
	}

	options, err := newConnOptions(c, opts)
	if err != nil {
		return nil, err
	}
	c.log = options.Logger

	reader := frame.NewReader(conn)
	if options.ReadBufferSize > 0 {
		reader = frame.NewReaderSize(conn, options.ReadBufferSize)
	}
	writer := frame.NewWriter(conn)
	if options.WriteBufferSize > 0 {
		// The writer must be sized from the write-buffer option, not the
		// read-buffer one.
		writer = frame.NewWriterSize(conn, options.WriteBufferSize)
	}

	readChannelCapacity := 20
	writeChannelCapacity := 20
	if options.ReadChannelCapacity > 0 {
		readChannelCapacity = options.ReadChannelCapacity
	}
	if options.WriteChannelCapacity > 0 {
		writeChannelCapacity = options.WriteChannelCapacity
	}

	c.hbGracePeriodMultiplier = options.HeartBeatGracePeriodMultiplier
	c.readCh = make(chan *frame.Frame, readChannelCapacity)
	c.writeCh = make(chan writeRequest, writeChannelCapacity)

	if options.Host == "" {
		// Host not specified yet: derive it from the remote address if the
		// transport is a net.Conn.
		if connection, ok := conn.(net.Conn); ok {
			host, _, err := net.SplitHostPort(connection.RemoteAddr().String())
			if err == nil {
				options.Host = host
			}
		}
		// If the host is still blank, use the default.
		if options.Host == "" {
			options.Host = "default"
		}
	}

	connectFrame, err := options.NewFrame()
	if err != nil {
		return nil, err
	}

	if err = writer.Write(connectFrame); err != nil {
		return nil, err
	}

	response, err := reader.Read()
	if err != nil {
		return nil, err
	}
	if response == nil {
		return nil, errors.New("unexpected empty frame")
	}
	if response.Command != frame.CONNECTED {
		return nil, newError(response)
	}

	c.server = response.Header.Get(frame.Server)
	c.session = response.Header.Get(frame.Session)

	if versionString := response.Header.Get(frame.Version); versionString != "" {
		version := Version(versionString)
		if err = version.CheckSupported(); err != nil {
			return nil, Error{
				Message: err.Error(),
				Frame:   response,
			}
		}
		c.version = version
	} else {
		// No version in the response, so assume version 1.0.
		c.version = V10
	}

	if heartBeat, ok := response.Header.Contains(frame.HeartBeat); ok {
		readTimeout, writeTimeout, err := frame.ParseHeartBeat(heartBeat)
		if err != nil {
			return nil, Error{
				Message: err.Error(),
				Frame:   response,
			}
		}

		c.readTimeout = readTimeout
		c.writeTimeout = writeTimeout

		if c.readTimeout > 0 {
			// Allow extra time for the other station's heart-beat to arrive.
			c.readTimeout += options.HeartBeatError
		}
		if c.writeTimeout > options.HeartBeatError {
			// Send our heart-beats early enough to reach the other station
			// in time.
			c.writeTimeout -= options.HeartBeatError
		}
	}

	c.msgSendTimeout = options.MsgSendTimeout
	c.rcvReceiptTimeout = options.RcvReceiptTimeout
	c.disconnectReceiptTimeout = options.DisconnectReceiptTimeout

	if options.ResponseHeadersCallback != nil {
		options.ResponseHeadersCallback(response.Header)
	}

	go readLoop(c, reader)
	go processLoop(c, writer)

	return c, nil
}
// Version reports the STOMP protocol version negotiated with the server
// during the connect handshake. When the server's CONNECTED frame carries
// no version header, this is V10.
func (c *Conn) Version() Version {
	return c.version
}
// Session reports the session header the server sent in its CONNECTED
// frame, or "" when the server did not send one.
func (c *Conn) Session() string {
	return c.session
}
// Server reports the server header (the server's identification) from the
// CONNECTED frame, or "" when the server did not send one.
func (c *Conn) Server() string {
	return c.server
}
// readLoop runs as its own goroutine. It pumps every frame read from reader
// (including nil heart-beat frames) into c.readCh for processLoop. On the
// first read error it closes c.readCh and exits, which tells processLoop
// that the connection is gone.
func readLoop(c *Conn, reader *frame.Reader) {
	for {
		incoming, readErr := reader.Read()
		if readErr != nil {
			break
		}
		c.readCh <- incoming
	}
	close(c.readCh)
}
// processLoop is the goroutine that owns all I/O with the server after the
// handshake. Its select loop multiplexes four event sources:
//   - the read heart-beat timer: if no frame arrives within readTimeout
//     (scaled by hbGracePeriodMultiplier), waiters get a "read timeout" error
//     and the loop exits;
//   - the write heart-beat timer: if nothing was written within writeTimeout,
//     a heart-beat (nil frame) is written;
//   - frames from readLoop on c.readCh (RECEIPT, ERROR, MESSAGE);
//   - requests from Send/Subscribe/... on c.writeCh.
//
// channels maps a key to the response channel of whoever waits on it. Keys
// are receipt ids, subscription ids and, for temporary-queue subscriptions,
// the reply-to destination. A single channel can therefore be registered
// under several keys. Any channel processLoop closes is removed under every
// key, and is closed only once, because sending to or re-closing a closed
// channel panics.
//
// Every exit path runs the deferred MustDisconnect.
func processLoop(c *Conn, writer *frame.Writer) {
	channels := make(map[string]chan *frame.Frame)
	var readTimeoutChannel <-chan time.Time
	var readTimer *time.Timer
	var writeTimeoutChannel <-chan time.Time
	var writeTimer *time.Timer

	defer c.MustDisconnect()

	for {
		if c.readTimeout > 0 && readTimer == nil {
			readTimer = time.NewTimer(time.Duration(float64(c.readTimeout) * c.hbGracePeriodMultiplier))
			readTimeoutChannel = readTimer.C
		}
		if c.writeTimeout > 0 && writeTimer == nil {
			writeTimer = time.NewTimer(c.writeTimeout)
			writeTimeoutChannel = writeTimer.C
		}

		select {
		case <-readTimeoutChannel:
			// Read timeout: the server went silent, close the connection.
			sendError(channels, newErrorMessage("read timeout"))
			return

		case <-writeTimeoutChannel:
			// Write timeout: send a heart-beat frame.
			if err := writer.Write(nil); err != nil {
				sendError(channels, err)
				return
			}
			writeTimer = nil
			writeTimeoutChannel = nil

		case f, ok := <-c.readCh:
			// Any traffic (including a heart-beat) resets the read timer.
			if readTimer != nil {
				readTimer.Stop()
				readTimer = nil
				readTimeoutChannel = nil
			}
			if !ok {
				sendError(channels, newErrorMessage("connection closed"))
				return
			}
			if f == nil {
				// Heart-beat received.
				continue
			}

			switch f.Command {
			case frame.RECEIPT:
				id, ok := f.Header.Contains(frame.ReceiptId)
				if !ok {
					sendError(channels, &Error{Message: "missing receipt-id", Frame: f})
					return
				}
				if ch, ok := channels[id]; ok {
					ch <- f
					// Drop every alias of ch (e.g. the reply-to key of an
					// unsubscribed subscription) before closing it.
					forgetChannel(channels, ch)
					close(ch)
				}

			case frame.ERROR:
				c.log.Error("received ERROR; Closing underlying connection")
				// Deliver to and close each distinct channel exactly once;
				// aliased keys would otherwise trigger a double close.
				seen := make(map[chan *frame.Frame]struct{}, len(channels))
				for _, ch := range channels {
					if _, dup := seen[ch]; dup {
						continue
					}
					seen[ch] = struct{}{}
					ch <- f
					close(ch)
				}
				// This deferred Unlock runs before the deferred
				// MustDisconnect, which then sees closed == true.
				c.closeMutex.Lock()
				defer c.closeMutex.Unlock()
				c.closed = true
				c.conn.Close()
				return

			case frame.MESSAGE:
				if id, ok := f.Header.Contains(frame.Subscription); ok {
					if ch, ok := channels[id]; ok {
						ch <- f
					} else {
						c.log.Infof("ignored MESSAGE for subscription: %s", id)
					}
				}
			}

		case req, ok := <-c.writeCh:
			// Any outgoing request resets the write heart-beat timer.
			if writeTimer != nil {
				writeTimer.Stop()
				writeTimer = nil
				writeTimeoutChannel = nil
			}
			if !ok {
				sendError(channels, errors.New("write channel closed"))
				return
			}

			if req.C != nil {
				if receipt, ok := req.Frame.Header.Contains(frame.Receipt); ok {
					// Remember the channel for this receipt.
					channels[receipt] = req.C
				}
			}

			// By default every request is written to the server.
			shouldWrite := true
			switch req.Frame.Command {
			case frame.SUBSCRIBE:
				id, _ := req.Frame.Header.Contains(frame.Id)
				channels[id] = req.C
				// A temp-queue subscription is routed locally by its reply-to
				// destination and is not sent: the destination is most likely
				// invalid on the broker.
				if replyTo, ok := req.Frame.Header.Contains(ReplyToHeader); ok {
					channels[replyTo] = req.C
					shouldWrite = false
				}
			case frame.UNSUBSCRIBE:
				id, _ := req.Frame.Header.Contains(frame.Id)
				// Ask for a receipt keyed by the subscription id so that the
				// server's RECEIPT closes the subscription's channel.
				req.Frame.Header.Set(frame.Receipt, id)
			}

			if shouldWrite {
				if err := writer.Write(req.Frame); err != nil {
					sendError(channels, err)
					return
				}
			}
		}
	}
}

// forgetChannel deletes every key of channels whose value is ch, so that a
// channel closed by processLoop cannot be reached again through an alias key.
func forgetChannel(channels map[string]chan *frame.Frame, ch chan *frame.Frame) {
	for key, registered := range channels {
		if registered == ch {
			delete(channels, key)
		}
	}
}
// sendError wraps err in a synthetic ERROR frame (message header = err's
// text) and delivers it to every channel registered in m. The send is
// blocking, once per map entry.
func sendError(m map[string]chan *frame.Frame, err error) {
	errFrame := frame.New(frame.ERROR, frame.Message, err.Error())
	for key := range m {
		m[key] <- errFrame
	}
}
// Disconnect will disconnect from the STOMP server. This function
// follows the STOMP standard's recommended protocol for graceful
// disconnection: it sends a DISCONNECT frame with a receipt header
// element. Once the RECEIPT frame has been received, the connection
// with the STOMP server is closed and any further attempt to write
// to the server will fail.
//
// Calling Disconnect on an already closed Conn is a no-op that returns nil.
// If no RECEIPT arrives within the disconnect receipt timeout, the
// connection is force-closed and ErrDisconnectReceiptTimeout is returned.
// Any other error, such as an ERROR reply, is returned without closing here.
// closeMutex is held for the whole exchange.
func (c *Conn) Disconnect() error {
	c.closeMutex.Lock()
	defer c.closeMutex.Unlock()
	if c.closed {
		return nil
	}

	// Capacity 1 lets processLoop hand over a RECEIPT that arrives after we
	// stopped waiting, instead of blocking forever on an unread channel.
	receiptCh := make(chan *frame.Frame, 1)
	c.writeCh <- writeRequest{
		Frame: frame.New(frame.DISCONNECT, frame.Receipt, allocateId()),
		C:     receiptCh,
	}

	err := readReceiptWithTimeout(receiptCh, c.disconnectReceiptTimeout, ErrDisconnectReceiptTimeout)
	switch {
	case err == nil:
		c.closed = true
		return c.conn.Close()
	case errors.Is(err, ErrDisconnectReceiptTimeout):
		c.closed = true
		// Best effort: the timeout is the error worth reporting.
		_ = c.conn.Close()
	}
	return err
}
// MustDisconnect tears the connection down without the DISCONNECT/RECEIPT
// handshake. Use it only as a last resort, when fatal network errors rule
// out a graceful Disconnect.
//
// If the Conn is not yet closed, it closes the write channel (which makes
// processLoop exit), marks the Conn closed and closes the transport,
// returning the transport's Close error. Otherwise it returns nil.
func (c *Conn) MustDisconnect() error {
	c.closeMutex.Lock()
	defer c.closeMutex.Unlock()

	if !c.closed {
		close(c.writeCh)
		c.closed = true
		return c.conn.Close()
	}
	return nil
}
// Send sends a message to the STOMP server, which in turn sends the message to the specified destination.
// If the STOMP server fails to receive the message for any reason, the connection will close.
//
// The content type should be specified, according to the STOMP specification, but if contentType is an empty
// string, the message will be delivered without a content-type header entry. The body array contains the
// message body, and its content should be consistent with the specified content type.
//
// Any number of options can be specified in opts. See the examples for usage. Options include whether
// to receive a RECEIPT, should the content-length be suppressed, and sending custom header entries.
//
// Queueing the frame is bounded by the connection's message-send timeout
// (ErrMsgSendTimeout). When a receipt is requested, Send then waits for it,
// bounded by the receipt timeout (ErrMsgReceiptTimeout). The close mutex is
// released before that wait, so a slow receipt does not stall other callers
// or the shutdown of the processing goroutine.
func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(*frame.Frame) error) error {
	request, err := c.enqueueSendRequest(destination, contentType, body, opts)
	if err != nil {
		return err
	}
	if request.C == nil {
		// No receipt requested: the frame is queued, nothing left to wait for.
		return nil
	}
	return readReceiptWithTimeout(request.C, c.rcvReceiptTimeout, ErrMsgReceiptTimeout)
}

// enqueueSendRequest builds the SEND frame and queues it on c.writeCh while
// holding closeMutex, so it cannot race with MustDisconnect closing the
// channel. The returned request has a receipt channel only when the frame
// asks for a receipt.
func (c *Conn) enqueueSendRequest(destination, contentType string, body []byte, opts []func(*frame.Frame) error) (writeRequest, error) {
	c.closeMutex.Lock()
	defer c.closeMutex.Unlock()
	if c.closed {
		return writeRequest{}, ErrAlreadyClosed
	}

	f, err := createSendFrame(destination, contentType, body, opts)
	if err != nil {
		return writeRequest{}, err
	}

	request := writeRequest{Frame: f}
	if _, ok := f.Header.Contains(frame.Receipt); ok {
		// Capacity 1 lets processLoop hand over a RECEIPT (or error) even if
		// Send already gave up after a timeout; with an unbuffered channel
		// processLoop would block forever on the send.
		request.C = make(chan *frame.Frame, 1)
	}

	if err := sendDataToWriteChWithTimeout(c.writeCh, request, c.msgSendTimeout); err != nil {
		return writeRequest{}, err
	}
	return request, nil
}
// readReceiptWithTimeout waits for the first frame on responseChan.
//
// It returns nil for a RECEIPT frame, and newError(frame) for any other
// frame, such as an ERROR. It returns ErrClosedUnexpectedly if the channel
// is closed without a frame, and timeoutErr if timeout (when positive)
// elapses first. A non-positive timeout waits indefinitely.
func readReceiptWithTimeout(responseChan chan *frame.Frame, timeout time.Duration, timeoutErr error) error {
	var timeoutChan <-chan time.Time
	if timeout > 0 {
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		timeoutChan = timer.C
	}

	select {
	case <-timeoutChan:
		return timeoutErr
	case response, ok := <-responseChan:
		if !ok {
			// A closed channel yields a nil frame; never dereference it.
			return ErrClosedUnexpectedly
		}
		if response.Command != frame.RECEIPT {
			return newError(response)
		}
		return nil
	}
}
// sendDataToWriteChWithTimeout queues request on ch. With a positive timeout
// it gives up after that long and returns ErrMsgSendTimeout; otherwise it
// blocks until the channel accepts the request.
func sendDataToWriteChWithTimeout(ch chan writeRequest, request writeRequest, timeout time.Duration) error {
	if timeout > 0 {
		deadline := time.NewTimer(timeout)
		defer deadline.Stop()
		select {
		case ch <- request:
			return nil
		case <-deadline.C:
			return ErrMsgSendTimeout
		}
	}

	ch <- request
	return nil
}
// createSendFrame builds a SEND frame for destination that carries body.
// content-length is set first, so an option can later remove it, and
// content-type is set only when contentType is non-empty. Nil options are
// skipped; the first option error aborts construction.
func createSendFrame(destination, contentType string, body []byte, opts []func(*frame.Frame) error) (*frame.Frame, error) {
	sendFrame := frame.New(frame.SEND, frame.ContentLength, strconv.Itoa(len(body)))
	sendFrame.Body = body
	sendFrame.Header.Set(frame.Destination, destination)
	if len(contentType) > 0 {
		sendFrame.Header.Set(frame.ContentType, contentType)
	}

	for i := range opts {
		apply := opts[i]
		if apply == nil {
			continue
		}
		if err := apply(sendFrame); err != nil {
			return nil, err
		}
	}
	return sendFrame, nil
}
// sendFrame queues an arbitrary frame (ACK, NACK, BEGIN, ...) for
// processLoop and, when the frame has a receipt header, waits for the reply.
//
// It returns ErrClosedUnexpectedly if the Conn is already closed. Queueing
// is bounded by msgSendTimeout (ErrMsgSendTimeout), matching Send. The close
// mutex is released before waiting for a receipt, so a slow reply cannot
// deadlock other callers. The receipt wait is bounded by writeTimeout when
// positive; on timeout ErrClosedUnexpectedly is returned.
func (c *Conn) sendFrame(f *frame.Frame) error {
	c.closeMutex.Lock()
	if c.closed {
		c.closeMutex.Unlock()
		c.conn.Close()
		return ErrClosedUnexpectedly
	}

	request := writeRequest{Frame: f}
	_, wantReceipt := f.Header.Contains(frame.Receipt)
	if wantReceipt {
		// Capacity 1 lets processLoop deliver a late receipt without
		// blocking forever once we have stopped waiting.
		request.C = make(chan *frame.Frame, 1)
	}
	err := sendDataToWriteChWithTimeout(c.writeCh, request, c.msgSendTimeout)
	// The request is queued (or queueing failed); writeCh is no longer
	// touched, so the mutex can be released before any wait.
	c.closeMutex.Unlock()
	if err != nil {
		return err
	}
	if !wantReceipt {
		return nil
	}

	var response *frame.Frame
	var ok bool
	// NOTE(review): the receipt wait reuses the heart-beat writeTimeout rather
	// than rcvReceiptTimeout; kept as-is for compatibility — confirm intent.
	if c.writeTimeout > 0 {
		timer := time.NewTimer(c.writeTimeout)
		defer timer.Stop()
		select {
		case response, ok = <-request.C:
		case <-timer.C:
			ok = false
		}
	} else {
		response, ok = <-request.C
	}

	if !ok {
		return ErrClosedUnexpectedly
	}
	if response.Command != frame.RECEIPT {
		return newError(response)
	}
	return nil
}
// Subscribe creates a subscription on the STOMP server.
// The subscription has a destination, and messages sent to that destination
// will be received by this subscription. A subscription has a channel
// on which the calling program can receive messages.
//
// If no "id" header is supplied through opts, one is allocated. It returns
// ErrClosedUnexpectedly if the Conn is closed. When msgSendTimeout is
// positive and the SUBSCRIBE request cannot be queued in time, it returns
// ErrMsgSendTimeout and starts no background reader.
func (c *Conn) Subscribe(destination string, ack AckMode, opts ...func(*frame.Frame) error) (*Subscription, error) {
	c.closeMutex.Lock()
	defer c.closeMutex.Unlock()
	if c.closed {
		c.conn.Close()
		return nil, ErrClosedUnexpectedly
	}

	subscribeFrame := frame.New(frame.SUBSCRIBE,
		frame.Destination, destination,
		frame.Ack, ack.String())

	for _, opt := range opts {
		if opt == nil {
			continue
		}
		if err := opt(subscribeFrame); err != nil {
			return nil, err
		}
	}

	// If the option functions have not specified the "id" header entry,
	// create one.
	id, ok := subscribeFrame.Header.Contains(frame.Id)
	if !ok {
		id = allocateId()
		subscribeFrame.Header.Add(frame.Id, id)
	}

	ch := make(chan *frame.Frame)
	request := writeRequest{
		Frame: subscribeFrame,
		C:     ch,
	}

	// writeCh is guaranteed open here: closed was checked under closeMutex,
	// and MustDisconnect only closes writeCh while holding that mutex. The
	// send is bounded so a wedged processLoop cannot hang Subscribe forever.
	if err := sendDataToWriteChWithTimeout(c.writeCh, request, c.msgSendTimeout); err != nil {
		return nil, err
	}

	closeMutex := &sync.Mutex{}
	sub := &Subscription{
		id:          id,
		destination: destination,
		conn:        c,
		ackMode:     ack,
		C:           make(chan *Message, 16),
		closeMutex:  closeMutex,
		closeCond:   sync.NewCond(closeMutex),
	}
	// Start the reader only once the request is queued, so a failed
	// Subscribe leaves no goroutine behind. processLoop blocks on ch until
	// this reader runs.
	go sub.readLoop(ch)

	return sub, nil
}
// TODO check further for race conditions

// Ack acknowledges a message received from the STOMP server. Messages that
// arrived on an AckAuto subscription need no acknowledgement, so nothing is
// sent and nil is returned.
func (c *Conn) Ack(m *Message) error {
	ackFrame, err := c.createAckNackFrame(m, true)
	if err != nil || ackFrame == nil {
		return err
	}
	return c.sendFrame(ackFrame)
}
// Nack tells the server that a message was not consumed by the client.
// It returns ErrNackNotSupported when the negotiated protocol version has no
// NACK frame, and ErrCannotNackAutoSub for messages from an AckAuto
// subscription.
func (c *Conn) Nack(m *Message) error {
	nackFrame, err := c.createAckNackFrame(m, false)
	if err != nil || nackFrame == nil {
		return err
	}
	return c.sendFrame(nackFrame)
}
// Begin starts a transaction. Transactions cover sending and acknowledging:
// messages sent or acknowledged within one are processed atomically by the
// STOMP server. Any error from sending the BEGIN frame is discarded; use
// BeginWithError to observe it.
func (c *Conn) Begin() *Transaction {
	tx, err := c.BeginWithError()
	_ = err // intentionally ignored; see BeginWithError
	return tx
}
// BeginWithError starts a transaction with a freshly allocated id and
// returns it together with any error from sending the BEGIN frame. The
// transaction value is returned even when that error is non-nil.
func (c *Conn) BeginWithError() (*Transaction, error) {
	txID := allocateId()
	tx := &Transaction{id: txID, conn: c}
	return tx, c.sendFrame(frame.New(frame.BEGIN, frame.Transaction, txID))
}
// createAckNackFrame builds the ACK (ack == true) or NACK frame for msg,
// using the header layout required by the negotiated protocol version.
//
// It returns (nil, nil) for an ACK on an AckAuto subscription, where there
// is nothing to acknowledge. Errors:
//   - ErrNackNotSupported: NACK requested but the version lacks it;
//   - ErrNotReceivedMessage: msg did not come from a subscription;
//   - ErrCannotNackAutoSub: NACK on an AckAuto subscription;
//   - missingHeader(...): the message lacks message-id (1.0/1.1) or ack (1.2).
func (c *Conn) createAckNackFrame(msg *Message, ack bool) (*frame.Frame, error) {
	if !ack && !c.version.SupportsNack() {
		return nil, ErrNackNotSupported
	}
	if msg.Header == nil || msg.Subscription == nil || msg.Conn == nil {
		return nil, ErrNotReceivedMessage
	}
	if msg.Subscription.AckMode() == AckAuto {
		if !ack {
			// A NACK makes no sense for an ack:auto subscription.
			return nil, ErrCannotNackAutoSub
		}
		// Nothing to acknowledge on an ack:auto subscription.
		return nil, nil
	}

	f := frame.New(frame.ACK)
	if !ack {
		f = frame.New(frame.NACK)
	}

	switch c.version {
	case V10, V11:
		f.Header.Add(frame.Subscription, msg.Subscription.Id())
		messageID, found := msg.Header.Contains(frame.MessageId)
		if !found {
			return nil, missingHeader(frame.MessageId)
		}
		f.Header.Add(frame.MessageId, messageID)
	case V12:
		// In 1.2 the MESSAGE frame's ack header is echoed back as the id.
		ackID, found := msg.Header.Contains(frame.Ack)
		if !found {
			return nil, missingHeader(frame.Ack)
		}
		f.Header.Add(frame.Id, ackID)
	}
	return f, nil
}