diff --git a/backend/services/stomp/AUTHORS.md b/backend/services/stomp/AUTHORS.md new file mode 100644 index 0000000..59921f3 --- /dev/null +++ b/backend/services/stomp/AUTHORS.md @@ -0,0 +1,20 @@ +## Authors + +* John Jeffery +* Hiram Jerónimo Pérez +* Alessandro Siragusa +* DaytonG +* Erik Benoist +* Evan Borgstrom +* Fernando Ribeiro +* Fredrik Rubensson +* Laurent Luce +* Oliver, Jonathan +* Paul P. Komkoff +* Raphael Tiersch +* Tom Lee +* Tony Le +* Voronkov Artem +* Whit Marbut +* hanjm +* yang.zhang4 diff --git a/backend/services/stomp/LICENSE.txt b/backend/services/stomp/LICENSE.txt new file mode 100644 index 0000000..343d70c --- /dev/null +++ b/backend/services/stomp/LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2012 The go-stomp authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/backend/services/stomp/README.md b/backend/services/stomp/README.md new file mode 100644 index 0000000..c1c37b7 --- /dev/null +++ b/backend/services/stomp/README.md @@ -0,0 +1,44 @@ +This STOMP implementation was forked from https://github.com/go-stomp/stomp and we customized it to accomplish with TR-369 protocol + +# stomp + +Go language implementation of a STOMP client library. + +![Build Status](https://github.com/go-stomp/stomp/actions/workflows/test.yml/badge.svg?branch=master) +[![Go Reference](https://pkg.go.dev/badge/github.com/go-stomp/stomp/v3.svg)](https://pkg.go.dev/github.com/go-stomp/stomp/v3) + +Features: + +* Supports STOMP Specifications Versions 1.0, 1.1, 1.2 (https://stomp.github.io/) +* Protocol negotiation to select the latest mutually supported protocol +* Heart beating for testing the underlying network connection +* Tested against RabbitMQ v3.0.1 + +## Usage Instructions + +``` +go get github.com/go-stomp/stomp/v3 +``` + +For API documentation, see https://pkg.go.dev/github.com/go-stomp/stomp/v3 + + +Breaking changes between this previous version and the current version are +documented in [breaking_changes.md](breaking_changes.md). + + +## License +Copyright 2012 - Present The go-stomp authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + diff --git a/backend/services/stomp/ack.go b/backend/services/stomp/ack.go new file mode 100644 index 0000000..ede59c0 --- /dev/null +++ b/backend/services/stomp/ack.go @@ -0,0 +1,48 @@ +package stomp + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// The AckMode type is an enumeration of the acknowledgement modes for a +// STOMP subscription. +type AckMode int + +// String returns the string representation of the AckMode value. +func (a AckMode) String() string { + switch a { + case AckAuto: + return frame.AckAuto + case AckClient: + return frame.AckClient + case AckClientIndividual: + return frame.AckClientIndividual + } + panic("invalid AckMode value") +} + +// ShouldAck returns true if this AckMode is an acknowledgement +// mode which requires acknowledgement. Returns true for all values +// except AckAuto, which returns false. +func (a AckMode) ShouldAck() bool { + switch a { + case AckAuto: + return false + case AckClient, AckClientIndividual: + return true + } + panic("invalid AckMode value") +} + +const ( + // No acknowledgement is required, the server assumes that the client + // received the message. + AckAuto AckMode = iota + + // Client acknowledges messages. When a client acknowledges a message, + // any previously received messages are also acknowledged. + AckClient + + // Client acknowledges message. Each message is acknowledged individually. + AckClientIndividual +) diff --git a/backend/services/stomp/breaking_changes.md b/backend/services/stomp/breaking_changes.md new file mode 100644 index 0000000..b60209f --- /dev/null +++ b/backend/services/stomp/breaking_changes.md @@ -0,0 +1,80 @@ +# Breaking Changes + +This document provides a list of breaking changes since the V1 release +of the stomp client library. + +## v2 and v3 + +### Module support + +Version 2 was released before module support was present in golang, and changes were tagged wit that version. +Therefore we had to update again the import path. + +The API it's stable the only breaking change is the import path. + +Version 3: +```go +import ( + "github.com/go-stomp/stomp/v3" +) +``` + +## v1 and v2 + +### 1. No longer using gopkg.in + +Version 1 of the library used Gustavo Niemeyer's `gopkg.in` facility for versioning Go libraries. +For a number of reasons, the `stomp` library no longer uses this facility. For this reason the +import path has changed. + +Version 1: +```go +import ( + "gopkg.in/stomp.v1" +) +``` + +Version 2: +```go +import ( + "github.com/go-stomp/stomp" +) +``` + +### 2. Frame types moved to frame package + +Version 1 of the library included a number of types to do with STOMP frames in the `stomp` +package, and the `frame` package consisted of just a few constant definitions. + +It was decided to move the following types out of the `stomp` package and into the `frame` package: + +* `stomp.Frame` -> `frame.Frame` +* `stomp.Header` -> `frame.Header` +* `stomp.Reader` -> `frame.Reader` +* `stomp.Writer` -> `frame.Writer` + +This change was considered worthwhile for the following reasons: + +* This change reduces the surface area of the `stomp` package and makes it easier to learn. +* Ideally, users of the `stomp` package do not need to directly reference the items in the `frame` +package, and the types moved are not needed in normal usage of the `stomp` package. + +### 3. Use of functional options + +Version 2 of the stomp library makes use of functional options to provide a clean, flexible way +of specifying options in the following API calls: + +* [Dial()](http://godoc.org/github.com/go-stomp/stomp#Dial) +* [Connect()](http://godoc.org/github.com/go-stomp/stomp#Connect) +* [Conn.Send()](http://godoc.org/github.com/go-stomp/stomp#Conn.Send) +* [Transaction.Send()](http://godoc.org/github.com/go-stomp/stomp#Transaction.Send) +* [Conn.Subscribe()](http://godoc.org/github.com/go-stomp/stomp#Conn.Subscribe) + +The idea for this comes from Dave Cheney's very excellent blog post, +[Functional Options for Friendly APIs](http://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis). + +While these new APIs are a definite improvement, they do introduce breaking changes with Version 1. + + + + diff --git a/backend/services/stomp/cmd/main.go b/backend/services/stomp/cmd/main.go index f634f86..eb12e48 100644 --- a/backend/services/stomp/cmd/main.go +++ b/backend/services/stomp/cmd/main.go @@ -5,7 +5,7 @@ import ( "net" "os" - stomp_server "github.com/go-stomp/stomp/server" + "github.com/go-stomp/stomp/v3/server" "github.com/joho/godotenv" ) @@ -47,14 +47,14 @@ func main() { Passwd: os.Getenv("STOMP_PASSWORD"), } - l, err := net.Listen("tcp", stomp_server.DefaultAddr) + l, err := net.Listen("tcp", server.DefaultAddr) if err != nil { log.Println("Error to open tcp port: ", err) } - s := stomp_server.Server{ - Addr: stomp_server.DefaultAddr, - HeartBeat: stomp_server.DefaultHeartBeat, + s := server.Server{ + Addr: server.DefaultAddr, + HeartBeat: server.DefaultHeartBeat, Authenticator: &creds, } diff --git a/backend/services/stomp/conn.go b/backend/services/stomp/conn.go new file mode 100644 index 0000000..a1b4bca --- /dev/null +++ b/backend/services/stomp/conn.go @@ -0,0 +1,774 @@ +package stomp + +import ( + "errors" + "io" + "net" + "strconv" + "sync" + "time" + + "github.com/go-stomp/stomp/v3/frame" +) + +// Default time span to add to read/write heart-beat timeouts +// to avoid premature disconnections due to network latency. +const DefaultHeartBeatError = 5 * time.Second + +// Default send timeout in Conn.Send function +const DefaultMsgSendTimeout = 10 * time.Second + +// Default receipt timeout in Conn.Send function +const DefaultRcvReceiptTimeout = 30 * time.Second + +// Default receipt timeout in Conn.Disconnect function +const DefaultDisconnectReceiptTimeout = 30 * time.Second + +// Reply-To header used for temporary queues/RPC with rabbit. +const ReplyToHeader = "reply-to" + +// A Conn is a connection to a STOMP server. Create a Conn using either +// the Dial or Connect function. +type Conn struct { + conn io.ReadWriteCloser + readCh chan *frame.Frame + writeCh chan writeRequest + version Version + session string + server string + readTimeout time.Duration + writeTimeout time.Duration + msgSendTimeout time.Duration + rcvReceiptTimeout time.Duration + disconnectReceiptTimeout time.Duration + hbGracePeriodMultiplier float64 + closed bool + closeMutex *sync.Mutex + options *connOptions + log Logger +} + +type writeRequest struct { + Frame *frame.Frame // frame to send + C chan *frame.Frame // response channel +} + +// Dial creates a network connection to a STOMP server and performs +// the STOMP connect protocol sequence. The network endpoint of the +// STOMP server is specified by network and addr. STOMP protocol +// options can be specified in opts. +func Dial(network, addr string, opts ...func(*Conn) error) (*Conn, error) { + c, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + + host, _, err := net.SplitHostPort(c.RemoteAddr().String()) + if err != nil { + c.Close() + return nil, err + } + + // Add option to set host and make it the first option in list, + // so that if host has been explicitly specified it will override. + opts = append([]func(*Conn) error{ConnOpt.Host(host)}, opts...) + + return Connect(c, opts...) +} + +// Connect creates a STOMP connection and performs the STOMP connect +// protocol sequence. The connection to the STOMP server has already +// been created by the program. The opts parameter provides the +// opportunity to specify STOMP protocol options. +func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error) { + reader := frame.NewReader(conn) + writer := frame.NewWriter(conn) + + c := &Conn{ + conn: conn, + closeMutex: &sync.Mutex{}, + } + + options, err := newConnOptions(c, opts) + if err != nil { + return nil, err + } + + c.log = options.Logger + + if options.ReadBufferSize > 0 { + reader = frame.NewReaderSize(conn, options.ReadBufferSize) + } + + if options.WriteBufferSize > 0 { + writer = frame.NewWriterSize(conn, options.ReadBufferSize) + } + + readChannelCapacity := 20 + writeChannelCapacity := 20 + + if options.ReadChannelCapacity > 0 { + readChannelCapacity = options.ReadChannelCapacity + } + + if options.WriteChannelCapacity > 0 { + writeChannelCapacity = options.WriteChannelCapacity + } + + c.hbGracePeriodMultiplier = options.HeartBeatGracePeriodMultiplier + + c.readCh = make(chan *frame.Frame, readChannelCapacity) + c.writeCh = make(chan writeRequest, writeChannelCapacity) + + if options.Host == "" { + // host not specified yet, attempt to get from net.Conn if possible + if connection, ok := conn.(net.Conn); ok { + host, _, err := net.SplitHostPort(connection.RemoteAddr().String()) + if err == nil { + options.Host = host + } + } + // if host is still blank, use default + if options.Host == "" { + options.Host = "default" + } + } + + connectFrame, err := options.NewFrame() + if err != nil { + return nil, err + } + + err = writer.Write(connectFrame) + if err != nil { + return nil, err + } + + response, err := reader.Read() + if err != nil { + return nil, err + } + if response == nil { + return nil, errors.New("unexpected empty frame") + } + + if response.Command != frame.CONNECTED { + return nil, newError(response) + } + + c.server = response.Header.Get(frame.Server) + c.session = response.Header.Get(frame.Session) + + if versionString := response.Header.Get(frame.Version); versionString != "" { + version := Version(versionString) + if err = version.CheckSupported(); err != nil { + return nil, Error{ + Message: err.Error(), + Frame: response, + } + } + c.version = version + } else { + // no version in the response, so assume version 1.0 + c.version = V10 + } + + if heartBeat, ok := response.Header.Contains(frame.HeartBeat); ok { + readTimeout, writeTimeout, err := frame.ParseHeartBeat(heartBeat) + if err != nil { + return nil, Error{ + Message: err.Error(), + Frame: response, + } + } + + c.readTimeout = readTimeout + c.writeTimeout = writeTimeout + + if c.readTimeout > 0 { + // Add time to the read timeout to account for time + // delay in other station transmitting timeout + c.readTimeout += options.HeartBeatError + } + if c.writeTimeout > options.HeartBeatError { + // Reduce time from the write timeout to account + // for time delay in transmitting to the other station + c.writeTimeout -= options.HeartBeatError + } + } + + c.msgSendTimeout = options.MsgSendTimeout + c.rcvReceiptTimeout = options.RcvReceiptTimeout + c.disconnectReceiptTimeout = options.DisconnectReceiptTimeout + + if options.ResponseHeadersCallback != nil { + options.ResponseHeadersCallback(response.Header) + } + + go readLoop(c, reader) + go processLoop(c, writer) + + return c, nil +} + +// Version returns the version of the STOMP protocol that +// is being used to communicate with the STOMP server. This +// version is negotiated with the server during the connect sequence. +func (c *Conn) Version() Version { + return c.version +} + +// Session returns the session identifier, which can be +// returned by the STOMP server during the connect sequence. +// If the STOMP server does not return a session header entry, +// this value will be a blank string. +func (c *Conn) Session() string { + return c.session +} + +// Server returns the STOMP server identification, which can +// be returned by the STOMP server during the connect sequence. +// If the STOMP server does not return a server header entry, +// this value will be a blank string. +func (c *Conn) Server() string { + return c.server +} + +// readLoop is a goroutine that reads frames from the +// reader and places them onto a channel for processing +// by the processLoop goroutine +func readLoop(c *Conn, reader *frame.Reader) { + for { + f, err := reader.Read() + if err != nil { + close(c.readCh) + return + } + c.readCh <- f + } +} + +// processLoop is a goroutine that handles io with +// the server. +func processLoop(c *Conn, writer *frame.Writer) { + channels := make(map[string]chan *frame.Frame) + + var readTimeoutChannel <-chan time.Time + var readTimer *time.Timer + var writeTimeoutChannel <-chan time.Time + var writeTimer *time.Timer + + defer c.MustDisconnect() + + for { + if c.readTimeout > 0 && readTimer == nil { + readTimer = time.NewTimer(time.Duration(float64(c.readTimeout) * c.hbGracePeriodMultiplier)) + readTimeoutChannel = readTimer.C + } + if c.writeTimeout > 0 && writeTimer == nil { + writeTimer = time.NewTimer(c.writeTimeout) + writeTimeoutChannel = writeTimer.C + } + + select { + case <-readTimeoutChannel: + // read timeout, close the connection + err := newErrorMessage("read timeout") + sendError(channels, err) + return + + case <-writeTimeoutChannel: + // write timeout, send a heart-beat frame + err := writer.Write(nil) + if err != nil { + sendError(channels, err) + return + } + writeTimer = nil + writeTimeoutChannel = nil + + case f, ok := <-c.readCh: + // stop the read timer + if readTimer != nil { + readTimer.Stop() + readTimer = nil + readTimeoutChannel = nil + } + + if !ok { + err := newErrorMessage("connection closed") + sendError(channels, err) + return + } + + if f == nil { + // heart-beat received + continue + } + + switch f.Command { + case frame.RECEIPT: + if id, ok := f.Header.Contains(frame.ReceiptId); ok { + if ch, ok := channels[id]; ok { + ch <- f + delete(channels, id) + close(ch) + } + } else { + err := &Error{Message: "missing receipt-id", Frame: f} + sendError(channels, err) + return + } + + case frame.ERROR: + c.log.Error("received ERROR; Closing underlying connection") + for _, ch := range channels { + ch <- f + close(ch) + } + + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + c.closed = true + c.conn.Close() + + return + + case frame.MESSAGE: + if id, ok := f.Header.Contains(frame.Subscription); ok { + if ch, ok := channels[id]; ok { + ch <- f + } else { + c.log.Infof("ignored MESSAGE for subscription: %s", id) + } + } + } + + case req, ok := <-c.writeCh: + // stop the write timeout + if writeTimer != nil { + writeTimer.Stop() + writeTimer = nil + writeTimeoutChannel = nil + } + if !ok { + sendError(channels, errors.New("write channel closed")) + return + } + if req.C != nil { + if receipt, ok := req.Frame.Header.Contains(frame.Receipt); ok { + // remember the channel for this receipt + channels[receipt] = req.C + } + } + + // default is to always send a frame. + var sendFrame = true + + switch req.Frame.Command { + case frame.SUBSCRIBE: + id, _ := req.Frame.Header.Contains(frame.Id) + channels[id] = req.C + + // if using a temp queue, map that destination as a known channel + // however, don't send the frame, it's most likely an invalid destination + // on the broker. + if replyTo, ok := req.Frame.Header.Contains(ReplyToHeader); ok { + channels[replyTo] = req.C + sendFrame = false + } + + case frame.UNSUBSCRIBE: + id, _ := req.Frame.Header.Contains(frame.Id) + // is this trying to be too clever -- add a receipt + // header so that when the server responds with a + // RECEIPT frame, the corresponding channel will be closed + req.Frame.Header.Set(frame.Receipt, id) + + } + + // frame to send, if enabled + if sendFrame { + err := writer.Write(req.Frame) + if err != nil { + sendError(channels, err) + return + } + } + } + } +} + +// Send an error to all receipt channels. +func sendError(m map[string]chan *frame.Frame, err error) { + frame := frame.New(frame.ERROR, frame.Message, err.Error()) + for _, ch := range m { + ch <- frame + } +} + +// Disconnect will disconnect from the STOMP server. This function +// follows the STOMP standard's recommended protocol for graceful +// disconnection: it sends a DISCONNECT frame with a receipt header +// element. Once the RECEIPT frame has been received, the connection +// with the STOMP server is closed and any further attempt to write +// to the server will fail. +func (c *Conn) Disconnect() error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + if c.closed { + return nil + } + + ch := make(chan *frame.Frame) + c.writeCh <- writeRequest{ + Frame: frame.New(frame.DISCONNECT, frame.Receipt, allocateId()), + C: ch, + } + + err := readReceiptWithTimeout(ch, c.disconnectReceiptTimeout, ErrDisconnectReceiptTimeout) + if err == nil { + c.closed = true + return c.conn.Close() + } + + if err == ErrDisconnectReceiptTimeout { + c.closed = true + _ = c.conn.Close() + } + + return err +} + +// MustDisconnect will disconnect 'ungracefully' from the STOMP server. +// This method should be used only as last resort when there are fatal +// network errors that prevent to do a proper disconnect from the server. +func (c *Conn) MustDisconnect() error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + if c.closed { + return nil + } + + // just close writeCh + close(c.writeCh) + + c.closed = true + return c.conn.Close() +} + +// Send sends a message to the STOMP server, which in turn sends the message to the specified destination. +// If the STOMP server fails to receive the message for any reason, the connection will close. +// +// The content type should be specified, according to the STOMP specification, but if contentType is an empty +// string, the message will be delivered without a content-type header entry. The body array contains the +// message body, and its content should be consistent with the specified content type. +// +// Any number of options can be specified in opts. See the examples for usage. Options include whether +// to receive a RECEIPT, should the content-length be suppressed, and sending custom header entries. +func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(*frame.Frame) error) error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + if c.closed { + return ErrAlreadyClosed + } + + f, err := createSendFrame(destination, contentType, body, opts) + if err != nil { + return err + } + + if _, ok := f.Header.Contains(frame.Receipt); ok { + // receipt required + request := writeRequest{ + Frame: f, + C: make(chan *frame.Frame), + } + + err := sendDataToWriteChWithTimeout(c.writeCh, request, c.msgSendTimeout) + if err != nil { + return err + } + + err = readReceiptWithTimeout(request.C, c.rcvReceiptTimeout, ErrMsgReceiptTimeout) + if err != nil { + return err + } + } else { + // no receipt required + request := writeRequest{Frame: f} + + err := sendDataToWriteChWithTimeout(c.writeCh, request, c.msgSendTimeout) + if err != nil { + return err + } + } + + return nil +} + +func readReceiptWithTimeout(responseChan chan *frame.Frame, timeout time.Duration, timeoutErr error) error { + var timeoutChan <-chan time.Time + if timeout > 0 { + timeoutChan = time.After(timeout) + } + + select { + case <-timeoutChan: + return timeoutErr + case response := <-responseChan: + if response.Command != frame.RECEIPT { + return newError(response) + } + return nil + } +} + +func sendDataToWriteChWithTimeout(ch chan writeRequest, request writeRequest, timeout time.Duration) error { + if timeout <= 0 { + ch <- request + return nil + } + + timer := time.NewTimer(timeout) + select { + case <-timer.C: + return ErrMsgSendTimeout + case ch <- request: + timer.Stop() + return nil + } +} + +func createSendFrame(destination, contentType string, body []byte, opts []func(*frame.Frame) error) (*frame.Frame, error) { + // Set the content-length before the options, because this provides + // an opportunity to remove content-length. + f := frame.New(frame.SEND, frame.ContentLength, strconv.Itoa(len(body))) + f.Body = body + f.Header.Set(frame.Destination, destination) + if contentType != "" { + f.Header.Set(frame.ContentType, contentType) + } + + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt(f); err != nil { + return nil, err + } + } + + return f, nil +} + +func (c *Conn) sendFrame(f *frame.Frame) error { + // Lock our mutex, but don't close it via defer + // If the frame requests a receipt then we want to release the lock before + // we block on the response, otherwise we can end up deadlocking + c.closeMutex.Lock() + if c.closed { + c.closeMutex.Unlock() + c.conn.Close() + return ErrClosedUnexpectedly + } + + if _, ok := f.Header.Contains(frame.Receipt); ok { + // receipt required + request := writeRequest{ + Frame: f, + C: make(chan *frame.Frame), + } + + c.writeCh <- request + + // Now that we've written to the writeCh channel we can release the + // close mutex while we wait for our response + c.closeMutex.Unlock() + + var response *frame.Frame + + if c.writeTimeout > 0 { + select { + case response, ok = <-request.C: + case <-time.After(c.writeTimeout): + ok = false + } + } else { + response, ok = <-request.C + } + + if ok { + if response.Command != frame.RECEIPT { + return newError(response) + } + } else { + return ErrClosedUnexpectedly + } + } else { + // no receipt required + request := writeRequest{Frame: f} + c.writeCh <- request + + // Unlock the mutex now that we're written to the write channel + c.closeMutex.Unlock() + } + + return nil +} + +// Subscribe creates a subscription on the STOMP server. +// The subscription has a destination, and messages sent to that destination +// will be received by this subscription. A subscription has a channel +// on which the calling program can receive messages. +func (c *Conn) Subscribe(destination string, ack AckMode, opts ...func(*frame.Frame) error) (*Subscription, error) { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + if c.closed { + c.conn.Close() + return nil, ErrClosedUnexpectedly + } + + ch := make(chan *frame.Frame) + + subscribeFrame := frame.New(frame.SUBSCRIBE, + frame.Destination, destination, + frame.Ack, ack.String()) + + for _, opt := range opts { + if opt == nil { + continue + } + err := opt(subscribeFrame) + if err != nil { + return nil, err + } + } + + // If the option functions have not specified the "id" header entry, + // create one. + id, ok := subscribeFrame.Header.Contains(frame.Id) + if !ok { + id = allocateId() + subscribeFrame.Header.Add(frame.Id, id) + } + + request := writeRequest{ + Frame: subscribeFrame, + C: ch, + } + + closeMutex := &sync.Mutex{} + sub := &Subscription{ + id: id, + destination: destination, + conn: c, + ackMode: ack, + C: make(chan *Message, 16), + closeMutex: closeMutex, + closeCond: sync.NewCond(closeMutex), + } + go sub.readLoop(ch) + + // TODO is this safe? There is no check if writeCh is actually open. + c.writeCh <- request + return sub, nil +} + +// TODO check further for race conditions + +// Ack acknowledges a message received from the STOMP server. +// If the message was received on a subscription with AckMode == AckAuto, +// then no operation is performed. +func (c *Conn) Ack(m *Message) error { + f, err := c.createAckNackFrame(m, true) + if err != nil { + return err + } + + if f != nil { + return c.sendFrame(f) + } + return nil +} + +// Nack indicates to the server that a message was not received +// by the client. Returns an error if the STOMP version does not +// support the NACK message. +func (c *Conn) Nack(m *Message) error { + f, err := c.createAckNackFrame(m, false) + if err != nil { + return err + } + + if f != nil { + return c.sendFrame(f) + } + return nil +} + +// Begin is used to start a transaction. Transactions apply to sending +// and acknowledging. Any messages sent or acknowledged during a transaction +// will be processed atomically by the STOMP server based on the transaction. +func (c *Conn) Begin() *Transaction { + t, _ := c.BeginWithError() + return t +} + +// BeginWithError is used to start a transaction, but also returns the error +// (if any) from sending the frame to start the transaction. +func (c *Conn) BeginWithError() (*Transaction, error) { + id := allocateId() + f := frame.New(frame.BEGIN, frame.Transaction, id) + err := c.sendFrame(f) + return &Transaction{id: id, conn: c}, err +} + +// Create an ACK or NACK frame. Complicated by version incompatibilities. +func (c *Conn) createAckNackFrame(msg *Message, ack bool) (*frame.Frame, error) { + if !ack && !c.version.SupportsNack() { + return nil, ErrNackNotSupported + } + + if msg.Header == nil || msg.Subscription == nil || msg.Conn == nil { + return nil, ErrNotReceivedMessage + } + + if msg.Subscription.AckMode() == AckAuto { + if ack { + // not much point sending an ACK to an auto subscription + return nil, nil + } else { + // sending a NACK for an ack:auto subscription makes no + // sense + return nil, ErrCannotNackAutoSub + } + } + + var f *frame.Frame + if ack { + f = frame.New(frame.ACK) + } else { + f = frame.New(frame.NACK) + } + + switch c.version { + case V10, V11: + f.Header.Add(frame.Subscription, msg.Subscription.Id()) + if messageId, ok := msg.Header.Contains(frame.MessageId); ok { + f.Header.Add(frame.MessageId, messageId) + } else { + return nil, missingHeader(frame.MessageId) + } + case V12: + // message frame contains ack header + if ack, ok := msg.Header.Contains(frame.Ack); ok { + // ack frame should reference it as id + f.Header.Add(frame.Id, ack) + } else { + return nil, missingHeader(frame.Ack) + } + } + + return f, nil +} diff --git a/backend/services/stomp/conn_options.go b/backend/services/stomp/conn_options.go new file mode 100644 index 0000000..9daddf7 --- /dev/null +++ b/backend/services/stomp/conn_options.go @@ -0,0 +1,327 @@ +package stomp + +import ( + "fmt" + "strings" + "time" + + "github.com/go-stomp/stomp/v3/frame" + "github.com/go-stomp/stomp/v3/internal/log" +) + +// ConnOptions is an opaque structure used to collection options +// for connecting to the other server. +type connOptions struct { + FrameCommand string + Host string + ReadTimeout time.Duration + WriteTimeout time.Duration + HeartBeatError time.Duration + MsgSendTimeout time.Duration + RcvReceiptTimeout time.Duration + DisconnectReceiptTimeout time.Duration + HeartBeatGracePeriodMultiplier float64 + Login, Passcode string + AcceptVersions []string + Header *frame.Header + ReadChannelCapacity, WriteChannelCapacity int + ReadBufferSize, WriteBufferSize int + ResponseHeadersCallback func(*frame.Header) + Logger Logger +} + +func newConnOptions(conn *Conn, opts []func(*Conn) error) (*connOptions, error) { + co := &connOptions{ + FrameCommand: frame.CONNECT, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + HeartBeatGracePeriodMultiplier: 1.0, + HeartBeatError: DefaultHeartBeatError, + MsgSendTimeout: DefaultMsgSendTimeout, + RcvReceiptTimeout: DefaultRcvReceiptTimeout, + DisconnectReceiptTimeout: DefaultDisconnectReceiptTimeout, + Logger: log.StdLogger{}, + } + + // This is a slight of hand, attach the options to the Conn long + // enough to run the options functions and then detach again. + // The reason we do this is to allow for future options to be able + // to modify the Conn object itself, in case that becomes desirable. + conn.options = co + defer func() { conn.options = nil }() + + // compatibility with previous version: ignore nil options + for _, opt := range opts { + if opt != nil { + err := opt(conn) + if err != nil { + return nil, err + } + } + } + + if len(co.AcceptVersions) == 0 { + co.AcceptVersions = append(co.AcceptVersions, string(V10), string(V11), string(V12)) + } + + return co, nil +} + +func (co *connOptions) NewFrame() (*frame.Frame, error) { + f := frame.New(co.FrameCommand) + if co.Host != "" { + f.Header.Set(frame.Host, co.Host) + } + + // heart-beat + { + send := co.WriteTimeout / time.Millisecond + recv := co.ReadTimeout / time.Millisecond + f.Header.Set(frame.HeartBeat, fmt.Sprintf("%d,%d", send, recv)) + } + + // login, passcode + if co.Login != "" || co.Passcode != "" { + f.Header.Set(frame.Login, co.Login) + f.Header.Set(frame.Passcode, co.Passcode) + } + + // accept-version + f.Header.Set(frame.AcceptVersion, strings.Join(co.AcceptVersions, ",")) + + // custom header entries -- note that these do not override + // header values already set as they are added to the end of + // the header array + f.Header.AddHeader(co.Header) + + return f, nil +} + +// Options for connecting to the STOMP server. Used with the +// stomp.Dial and stomp.Connect functions, both of which have examples. +var ConnOpt struct { + // Login is a connect option that allows the calling program to + // specify the "login" and "passcode" values to send to the STOMP + // server. + Login func(login, passcode string) func(*Conn) error + + // Host is a connect option that allows the calling program to + // specify the value of the "host" header. + Host func(host string) func(*Conn) error + + // UseStomp is a connect option that specifies that the client + // should use the "STOMP" command instead of the "CONNECT" command. + // Note that using "STOMP" is only valid for STOMP version 1.1 and later. + UseStomp func(*Conn) error + + // AcceptVersoin is a connect option that allows the client to + // specify one or more versions of the STOMP protocol that the + // client program is prepared to accept. If this option is not + // specified, the client program will accept any of STOMP versions + // 1.0, 1.1 or 1.2. + AcceptVersion func(versions ...Version) func(*Conn) error + + // HeartBeat is a connect option that allows the client to specify + // the send and receive timeouts for the STOMP heartbeat negotiation mechanism. + // The sendTimeout parameter specifies the maximum amount of time + // between the client sending heartbeat notifications from the server. + // The recvTimeout paramter specifies the minimum amount of time between + // the client expecting to receive heartbeat notifications from the server. + // If not specified, this option defaults to one minute for both send and receive + // timeouts. + HeartBeat func(sendTimeout, recvTimeout time.Duration) func(*Conn) error + + // HeartBeatError is a connect option that will normally only be specified during + // testing. It specifies a short time duration that is larger than the amount of time + // that will take for a STOMP frame to be transmitted from one station to the other. + // When not specified, this value defaults to 5 seconds. This value is set to a much + // shorter time duration during unit testing. + HeartBeatError func(errorTimeout time.Duration) func(*Conn) error + + // MsgSendTimeout is a connect option that allows the client to specify + // the timeout for the Conn.Send function. + // The msgSendTimeout parameter specifies maximum blocking time for calling + // the Conn.Send function. + // If not specified, this option defaults to 10 seconds. + // Less than or equal to zero means infinite + MsgSendTimeout func(msgSendTimeout time.Duration) func(*Conn) error + + // RcvReceiptTimeout is a connect option that allows the client to specify + // how long to wait for a receipt in the Conn.Send function. This helps + // avoid deadlocks. If this is not specified, the default is 30 seconds. + RcvReceiptTimeout func(rcvReceiptTimeout time.Duration) func(*Conn) error + + // DisconnectReceiptTimeout is a connect option that allows the client to specify + // how long to wait for a receipt in the Conn.Disconnect function. This helps + // avoid deadlocks. If this is not specified, the default is 30 seconds. + DisconnectReceiptTimeout func(disconnectReceiptTimeout time.Duration) func(*Conn) error + + // HeartBeatGracePeriodMultiplier is used to calculate the effective read heart-beat timeout + // the broker will enforce for each client’s connection. The multiplier is applied to + // the read-timeout interval the client specifies in its CONNECT frame + HeartBeatGracePeriodMultiplier func(multiplier float64) func(*Conn) error + + // Header is a connect option that allows the client to specify a custom + // header entry in the STOMP frame. This connect option can be specified + // multiple times for multiple custom headers. + Header func(key, value string) func(*Conn) error + + // ReadChannelCapacity is the number of messages that can be on the read channel at the + // same time. A high number may affect memory usage while a too low number may lock the + // system up. Default is set to 20. + ReadChannelCapacity func(capacity int) func(*Conn) error + + // WriteChannelCapacity is the number of messages that can be on the write channel at the + // same time. A high number may affect memory usage while a too low number may lock the + // system up. Default is set to 20. + WriteChannelCapacity func(capacity int) func(*Conn) error + + // ReadBufferSize specifies number of bytes that can be used to read the message + // A high number may affect memory usage while a too low number may lock the + // system up. Default is set to 4096. + ReadBufferSize func(size int) func(*Conn) error + + // WriteBufferSize specifies number of bytes that can be used to write the message + // A high number may affect memory usage while a too low number may lock the + // system up. Default is set to 4096. + WriteBufferSize func(size int) func(*Conn) error + + // ResponseHeaders lets you provide a callback function to get the headers from the CONNECT response + ResponseHeaders func(func(*frame.Header)) func(*Conn) error + + // Logger lets you provide a callback function that sets the logger used by a connection + Logger func(logger Logger) func(*Conn) error +} + +func init() { + ConnOpt.Login = func(login, passcode string) func(*Conn) error { + return func(c *Conn) error { + c.options.Login = login + c.options.Passcode = passcode + return nil + } + } + + ConnOpt.Host = func(host string) func(*Conn) error { + return func(c *Conn) error { + c.options.Host = host + return nil + } + } + + ConnOpt.UseStomp = func(c *Conn) error { + c.options.FrameCommand = frame.STOMP + return nil + } + + ConnOpt.AcceptVersion = func(versions ...Version) func(*Conn) error { + return func(c *Conn) error { + for _, version := range versions { + if err := version.CheckSupported(); err != nil { + return err + } + c.options.AcceptVersions = append(c.options.AcceptVersions, string(version)) + } + return nil + } + } + + ConnOpt.HeartBeat = func(sendTimeout, recvTimeout time.Duration) func(*Conn) error { + return func(c *Conn) error { + c.options.WriteTimeout = sendTimeout + c.options.ReadTimeout = recvTimeout + return nil + } + } + + ConnOpt.HeartBeatError = func(errorTimeout time.Duration) func(*Conn) error { + return func(c *Conn) error { + c.options.HeartBeatError = errorTimeout + return nil + } + } + + ConnOpt.MsgSendTimeout = func(msgSendTimeout time.Duration) func(*Conn) error { + return func(c *Conn) error { + c.options.MsgSendTimeout = msgSendTimeout + return nil + } + } + + ConnOpt.RcvReceiptTimeout = func(rcvReceiptTimeout time.Duration) func(*Conn) error { + return func(c *Conn) error { + c.options.RcvReceiptTimeout = rcvReceiptTimeout + return nil + } + } + + ConnOpt.DisconnectReceiptTimeout = func(disconnectReceiptTimeout time.Duration) func(*Conn) error { + return func(c *Conn) error { + c.options.DisconnectReceiptTimeout = disconnectReceiptTimeout + return nil + } + } + + ConnOpt.HeartBeatGracePeriodMultiplier = func(multiplier float64) func(*Conn) error { + return func(c *Conn) error { + c.options.HeartBeatGracePeriodMultiplier = multiplier + return nil + } + } + + ConnOpt.Header = func(key, value string) func(*Conn) error { + return func(c *Conn) error { + if c.options.Header == nil { + c.options.Header = frame.NewHeader(key, value) + } else { + c.options.Header.Add(key, value) + } + return nil + } + } + + ConnOpt.ReadChannelCapacity = func(capacity int) func(*Conn) error { + return func(c *Conn) error { + c.options.ReadChannelCapacity = capacity + return nil + } + } + + ConnOpt.WriteChannelCapacity = func(capacity int) func(*Conn) error { + return func(c *Conn) error { + c.options.WriteChannelCapacity = capacity + return nil + } + } + + ConnOpt.ReadBufferSize = func(size int) func(*Conn) error { + return func(c *Conn) error { + c.options.ReadBufferSize = size + return nil + } + } + + ConnOpt.WriteBufferSize = func(size int) func(*Conn) error { + return func(c *Conn) error { + c.options.WriteBufferSize = size + return nil + } + } + + ConnOpt.ResponseHeaders = func(callback func(*frame.Header)) func(*Conn) error { + return func(c *Conn) error { + c.options.ResponseHeadersCallback = callback + return nil + } + } + + ConnOpt.Logger = func(log Logger) func(*Conn) error { + return func(c *Conn) error { + if log != nil { + c.options.Logger = log + } + + return nil + } + } +} diff --git a/backend/services/stomp/conn_test.go b/backend/services/stomp/conn_test.go new file mode 100644 index 0000000..cbb5a9d --- /dev/null +++ b/backend/services/stomp/conn_test.go @@ -0,0 +1,786 @@ +package stomp + +import ( + "fmt" + "io" + "time" + + "github.com/go-stomp/stomp/v3/frame" + "github.com/go-stomp/stomp/v3/testutil" + + "github.com/golang/mock/gomock" + . "gopkg.in/check.v1" +) + +type fakeReaderWriter struct { + reader *frame.Reader + writer *frame.Writer + conn io.ReadWriteCloser +} + +func (rw *fakeReaderWriter) Read() (*frame.Frame, error) { + return rw.reader.Read() +} + +func (rw *fakeReaderWriter) Write(f *frame.Frame) error { + return rw.writer.Write(f) +} + +func (rw *fakeReaderWriter) Close() error { + return rw.conn.Close() +} + +func (s *StompSuite) Test_conn_option_set_logger(c *C) { + fc1, fc2 := testutil.NewFakeConn(c) + go func() { + + defer func() { + fc2.Close() + fc1.Close() + }() + + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + f2 := frame.New("CONNECTED") + err = writer.Write(f2) + c.Assert(err, IsNil) + }() + + ctrl := gomock.NewController(s.t) + mockLogger := testutil.NewMockLogger(ctrl) + + conn, err := Connect(fc1, ConnOpt.Logger(mockLogger)) + c.Assert(err, IsNil) + c.Check(conn, NotNil) + + c.Assert(conn.log, Equals, mockLogger) +} + +func (s *StompSuite) Test_unsuccessful_connect(c *C) { + fc1, fc2 := testutil.NewFakeConn(c) + stop := make(chan struct{}) + + go func() { + defer func() { + fc2.Close() + close(stop) + }() + + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + f2 := frame.New("ERROR", "message", "auth-failed") + err = writer.Write(f2) + c.Assert(err, IsNil) + }() + + conn, err := Connect(fc1) + c.Assert(conn, IsNil) + c.Assert(err, ErrorMatches, "auth-failed") +} + +func (s *StompSuite) Test_successful_connect_and_disconnect(c *C) { + testcases := []struct { + Options []func(*Conn) error + NegotiatedVersion string + ExpectedVersion Version + ExpectedSession string + ExpectedHost string + ExpectedServer string + }{ + { + Options: []func(*Conn) error{ConnOpt.Host("the-server")}, + ExpectedVersion: "1.0", + ExpectedSession: "", + ExpectedHost: "the-server", + ExpectedServer: "some-server/1.1", + }, + { + Options: []func(*Conn) error{}, + NegotiatedVersion: "1.1", + ExpectedVersion: "1.1", + ExpectedSession: "the-session", + ExpectedHost: "the-server", + }, + { + Options: []func(*Conn) error{ConnOpt.Host("xxx")}, + NegotiatedVersion: "1.2", + ExpectedVersion: "1.2", + ExpectedSession: "the-session", + ExpectedHost: "xxx", + }, + } + + for _, tc := range testcases { + resetId() + fc1, fc2 := testutil.NewFakeConn(c) + stop := make(chan struct{}) + + go func() { + defer func() { + fc2.Close() + close(stop) + }() + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + host, _ := f1.Header.Contains("host") + c.Check(host, Equals, tc.ExpectedHost) + connectedFrame := frame.New("CONNECTED") + if tc.NegotiatedVersion != "" { + connectedFrame.Header.Add("version", tc.NegotiatedVersion) + } + if tc.ExpectedSession != "" { + connectedFrame.Header.Add("session", tc.ExpectedSession) + } + if tc.ExpectedServer != "" { + connectedFrame.Header.Add("server", tc.ExpectedServer) + } + err = writer.Write(connectedFrame) + c.Assert(err, IsNil) + + f2, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f2.Command, Equals, "DISCONNECT") + receipt, _ := f2.Header.Contains("receipt") + c.Check(receipt, Equals, "1") + + err = writer.Write(frame.New("RECEIPT", frame.ReceiptId, "1")) + c.Assert(err, IsNil) + + }() + + client, err := Connect(fc1, tc.Options...) + c.Assert(err, IsNil) + c.Assert(client, NotNil) + c.Assert(client.Version(), Equals, tc.ExpectedVersion) + c.Assert(client.Session(), Equals, tc.ExpectedSession) + c.Assert(client.Server(), Equals, tc.ExpectedServer) + + err = client.Disconnect() + c.Assert(err, IsNil) + + <-stop + } +} + +func (s *StompSuite) Test_successful_connect_get_headers(c *C) { + var respHeaders *frame.Header + + testcases := []struct { + Options []func(*Conn) error + Headers map[string]string + }{ + { + Options: []func(*Conn) error{ConnOpt.ResponseHeaders(func(f *frame.Header) { respHeaders = f })}, + Headers: map[string]string{"custom-header": "test", "foo": "bar"}, + }, + } + + for _, tc := range testcases { + resetId() + fc1, fc2 := testutil.NewFakeConn(c) + stop := make(chan struct{}) + + go func() { + defer func() { + fc2.Close() + close(stop) + }() + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + connectedFrame := frame.New("CONNECTED") + for key, value := range tc.Headers { + connectedFrame.Header.Add(key, value) + } + err = writer.Write(connectedFrame) + c.Assert(err, IsNil) + + f2, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f2.Command, Equals, "DISCONNECT") + receipt, _ := f2.Header.Contains("receipt") + c.Check(receipt, Equals, "1") + + err = writer.Write(frame.New("RECEIPT", frame.ReceiptId, "1")) + c.Assert(err, IsNil) + + }() + + client, err := Connect(fc1, tc.Options...) + c.Assert(err, IsNil) + c.Assert(client, NotNil) + c.Assert(respHeaders, NotNil) + for key, value := range tc.Headers { + c.Assert(respHeaders.Get(key), Equals, value) + } + err = client.Disconnect() + c.Assert(err, IsNil) + + <-stop + } +} + +func (s *StompSuite) Test_successful_connect_with_nonstandard_header(c *C) { + resetId() + fc1, fc2 := testutil.NewFakeConn(c) + stop := make(chan struct{}) + + go func() { + defer func() { + fc2.Close() + close(stop) + }() + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + c.Assert(f1.Header.Get("login"), Equals, "guest") + c.Assert(f1.Header.Get("passcode"), Equals, "guest") + c.Assert(f1.Header.Get("host"), Equals, "/") + c.Assert(f1.Header.Get("x-max-length"), Equals, "50") + connectedFrame := frame.New("CONNECTED") + connectedFrame.Header.Add("session", "session-0voRHrG-VbBedx1Gwwb62Q") + connectedFrame.Header.Add("heart-beat", "0,0") + connectedFrame.Header.Add("server", "RabbitMQ/3.2.1") + connectedFrame.Header.Add("version", "1.0") + err = writer.Write(connectedFrame) + c.Assert(err, IsNil) + + f2, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f2.Command, Equals, "DISCONNECT") + receipt, _ := f2.Header.Contains("receipt") + c.Check(receipt, Equals, "1") + + err = writer.Write(frame.New("RECEIPT", frame.ReceiptId, "1")) + c.Assert(err, IsNil) + }() + + client, err := Connect(fc1, + ConnOpt.Login("guest", "guest"), + ConnOpt.Host("/"), + ConnOpt.Header("x-max-length", "50")) + c.Assert(err, IsNil) + c.Assert(client, NotNil) + c.Assert(client.Version(), Equals, V10) + c.Assert(client.Session(), Equals, "session-0voRHrG-VbBedx1Gwwb62Q") + c.Assert(client.Server(), Equals, "RabbitMQ/3.2.1") + + err = client.Disconnect() + c.Assert(err, IsNil) + + <-stop +} + +func (s *StompSuite) Test_connect_not_panic_on_empty_response(c *C) { + resetId() + fc1, fc2 := testutil.NewFakeConn(c) + stop := make(chan struct{}) + + go func() { + defer func() { + fc2.Close() + close(stop) + }() + reader := frame.NewReader(fc2) + _, err := reader.Read() + c.Assert(err, IsNil) + _, err = fc2.Write([]byte("\n")) + c.Assert(err, IsNil) + }() + + client, err := Connect(fc1, ConnOpt.Host("the_server")) + c.Assert(err, NotNil) + c.Assert(client, IsNil) + + fc1.Close() + <-stop +} + +func (s *StompSuite) Test_successful_disconnect_with_receipt_timeout(c *C) { + resetId() + fc1, fc2 := testutil.NewFakeConn(c) + + defer func() { + fc2.Close() + }() + + go func() { + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + connectedFrame := frame.New("CONNECTED") + err = writer.Write(connectedFrame) + c.Assert(err, IsNil) + }() + + client, err := Connect(fc1, ConnOpt.DisconnectReceiptTimeout(1 * time.Nanosecond)) + c.Assert(err, IsNil) + c.Assert(client, NotNil) + + err = client.Disconnect() + c.Assert(err, Equals, ErrDisconnectReceiptTimeout) + c.Assert(client.closed, Equals, true) +} + +// Sets up a connection for testing +func connectHelper(c *C, version Version) (*Conn, *fakeReaderWriter) { + fc1, fc2 := testutil.NewFakeConn(c) + stop := make(chan struct{}) + + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + + go func() { + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + f2 := frame.New("CONNECTED", "version", version.String()) + err = writer.Write(f2) + c.Assert(err, IsNil) + close(stop) + }() + + conn, err := Connect(fc1) + c.Assert(err, IsNil) + c.Assert(conn, NotNil) + <-stop + return conn, &fakeReaderWriter{ + reader: reader, + writer: writer, + conn: fc2, + } +} + +func (s *StompSuite) Test_subscribe(c *C) { + ackModes := []AckMode{AckAuto, AckClient, AckClientIndividual} + versions := []Version{V10, V11, V12} + + for _, ackMode := range ackModes { + for _, version := range versions { + subscribeHelper(c, ackMode, version) + subscribeHelper(c, ackMode, version, + SubscribeOpt.Header("id", "client-1"), + SubscribeOpt.Header("custom", "true")) + } + } +} + +func subscribeHelper(c *C, ackMode AckMode, version Version, opts ...func(*frame.Frame) error) { + conn, rw := connectHelper(c, version) + stop := make(chan struct{}) + + go func() { + defer func() { + rw.Close() + close(stop) + }() + + f3, err := rw.Read() + c.Assert(err, IsNil) + c.Assert(f3.Command, Equals, "SUBSCRIBE") + + id, ok := f3.Header.Contains("id") + c.Assert(ok, Equals, true) + + destination := f3.Header.Get("destination") + c.Assert(destination, Equals, "/queue/test-1") + ack := f3.Header.Get("ack") + c.Assert(ack, Equals, ackMode.String()) + + for i := 1; i <= 5; i++ { + messageId := fmt.Sprintf("message-%d", i) + bodyText := fmt.Sprintf("Message body %d", i) + f4 := frame.New("MESSAGE", + frame.Subscription, id, + frame.MessageId, messageId, + frame.Destination, destination) + if version == V12 && ackMode.ShouldAck() { + f4.Header.Add(frame.Ack, messageId) + } + f4.Body = []byte(bodyText) + err = rw.Write(f4) + c.Assert(err, IsNil) + + if ackMode.ShouldAck() { + f5, _ := rw.Read() + c.Assert(f5.Command, Equals, "ACK") + if version == V12 { + c.Assert(f5.Header.Get(frame.Id), Equals, messageId) + } else { + c.Assert(f5.Header.Get("subscription"), Equals, id) + c.Assert(f5.Header.Get("message-id"), Equals, messageId) + } + } + } + + f6, _ := rw.Read() + c.Assert(f6.Command, Equals, "UNSUBSCRIBE") + c.Assert(f6.Header.Get(frame.Receipt), Not(Equals), "") + c.Assert(f6.Header.Get(frame.Id), Equals, id) + err = rw.Write(frame.New(frame.RECEIPT, frame.ReceiptId, f6.Header.Get(frame.Receipt))) + c.Assert(err, IsNil) + + f7, _ := rw.Read() + c.Assert(f7.Command, Equals, "DISCONNECT") + err = rw.Write(frame.New(frame.RECEIPT, frame.ReceiptId, f7.Header.Get(frame.Receipt))) + c.Assert(err, IsNil) + }() + + var sub *Subscription + var err error + sub, err = conn.Subscribe("/queue/test-1", ackMode, opts...) + + c.Assert(sub, NotNil) + c.Assert(err, IsNil) + + for i := 1; i <= 5; i++ { + msg := <-sub.C + messageId := fmt.Sprintf("message-%d", i) + bodyText := fmt.Sprintf("Message body %d", i) + c.Assert(msg.Subscription, Equals, sub) + c.Assert(msg.Body, DeepEquals, []byte(bodyText)) + c.Assert(msg.Destination, Equals, "/queue/test-1") + c.Assert(msg.Header.Get(frame.MessageId), Equals, messageId) + if version == V12 && ackMode.ShouldAck() { + c.Assert(msg.Header.Get(frame.Ack), Equals, messageId) + } + + c.Assert(msg.ShouldAck(), Equals, ackMode.ShouldAck()) + if msg.ShouldAck() { + err = msg.Conn.Ack(msg) + c.Assert(err, IsNil) + } + } + + err = sub.Unsubscribe(SubscribeOpt.Header("custom", "true")) + c.Assert(err, IsNil) + + err = conn.Disconnect() + c.Assert(err, IsNil) +} + +func (s *StompSuite) TestTransaction(c *C) { + + ackModes := []AckMode{AckAuto, AckClient, AckClientIndividual} + versions := []Version{V10, V11, V12} + aborts := []bool{false, true} + nacks := []bool{false, true} + + for _, ackMode := range ackModes { + for _, version := range versions { + for _, abort := range aborts { + for _, nack := range nacks { + subscribeTransactionHelper(c, ackMode, version, abort, nack) + } + } + } + } +} + +func subscribeTransactionHelper(c *C, ackMode AckMode, version Version, abort bool, nack bool) { + conn, rw := connectHelper(c, version) + stop := make(chan struct{}) + + go func() { + defer func() { + rw.Close() + close(stop) + }() + + f3, err := rw.Read() + c.Assert(err, IsNil) + c.Assert(f3.Command, Equals, "SUBSCRIBE") + id, ok := f3.Header.Contains("id") + c.Assert(ok, Equals, true) + destination := f3.Header.Get("destination") + c.Assert(destination, Equals, "/queue/test-1") + ack := f3.Header.Get("ack") + c.Assert(ack, Equals, ackMode.String()) + + for i := 1; i <= 5; i++ { + messageId := fmt.Sprintf("message-%d", i) + bodyText := fmt.Sprintf("Message body %d", i) + f4 := frame.New("MESSAGE", + frame.Subscription, id, + frame.MessageId, messageId, + frame.Destination, destination) + if version == V12 && ackMode.ShouldAck() { + f4.Header.Add(frame.Ack, messageId) + } + f4.Body = []byte(bodyText) + err = rw.Write(f4) + c.Assert(err, IsNil) + + beginFrame, err := rw.Read() + c.Assert(err, IsNil) + c.Assert(beginFrame, NotNil) + c.Check(beginFrame.Command, Equals, "BEGIN") + tx, ok := beginFrame.Header.Contains(frame.Transaction) + + c.Assert(ok, Equals, true) + + if ackMode.ShouldAck() { + f5, _ := rw.Read() + if nack && version.SupportsNack() { + c.Assert(f5.Command, Equals, "NACK") + } else { + c.Assert(f5.Command, Equals, "ACK") + } + if version == V12 { + c.Assert(f5.Header.Get(frame.Id), Equals, messageId) + } else { + c.Assert(f5.Header.Get("subscription"), Equals, id) + c.Assert(f5.Header.Get("message-id"), Equals, messageId) + } + c.Assert(f5.Header.Get("transaction"), Equals, tx) + } + + sendFrame, _ := rw.Read() + c.Assert(sendFrame, NotNil) + c.Assert(sendFrame.Command, Equals, "SEND") + c.Assert(sendFrame.Header.Get("transaction"), Equals, tx) + + commitFrame, _ := rw.Read() + c.Assert(commitFrame, NotNil) + if abort { + c.Assert(commitFrame.Command, Equals, "ABORT") + } else { + c.Assert(commitFrame.Command, Equals, "COMMIT") + } + c.Assert(commitFrame.Header.Get("transaction"), Equals, tx) + } + + f6, _ := rw.Read() + c.Assert(f6.Command, Equals, "UNSUBSCRIBE") + c.Assert(f6.Header.Get(frame.Receipt), Not(Equals), "") + c.Assert(f6.Header.Get(frame.Id), Equals, id) + err = rw.Write(frame.New(frame.RECEIPT, frame.ReceiptId, f6.Header.Get(frame.Receipt))) + c.Assert(err, IsNil) + + f7, _ := rw.Read() + c.Assert(f7.Command, Equals, "DISCONNECT") + err = rw.Write(frame.New(frame.RECEIPT, frame.ReceiptId, f7.Header.Get(frame.Receipt))) + c.Assert(err, IsNil) + }() + + sub, err := conn.Subscribe("/queue/test-1", ackMode) + c.Assert(sub, NotNil) + c.Assert(err, IsNil) + + for i := 1; i <= 5; i++ { + msg := <-sub.C + messageId := fmt.Sprintf("message-%d", i) + bodyText := fmt.Sprintf("Message body %d", i) + c.Assert(msg.Subscription, Equals, sub) + c.Assert(msg.Body, DeepEquals, []byte(bodyText)) + c.Assert(msg.Destination, Equals, "/queue/test-1") + c.Assert(msg.Header.Get(frame.MessageId), Equals, messageId) + + c.Assert(msg.ShouldAck(), Equals, ackMode.ShouldAck()) + tx := msg.Conn.Begin() + c.Assert(tx.Id(), Not(Equals), "") + if msg.ShouldAck() { + if nack && version.SupportsNack() { + err = tx.Nack(msg) + c.Assert(err, IsNil) + } else { + err = tx.Ack(msg) + c.Assert(err, IsNil) + } + } + err = tx.Send("/queue/another-queue", "text/plain", []byte(bodyText)) + c.Assert(err, IsNil) + if abort { + err = tx.Abort() + c.Assert(err, IsNil) + } else { + err = tx.Commit() + c.Assert(err, IsNil) + } + } + + err = sub.Unsubscribe() + c.Assert(err, IsNil) + + err = conn.Disconnect() + c.Assert(err, IsNil) +} + +func (s *StompSuite) TestHeartBeatReadTimeout(c *C) { + conn, rw := createHeartBeatConnection(c, 100, 10000, time.Millisecond) + + go func() { + f1, err := rw.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "SUBSCRIBE") + messageFrame := frame.New("MESSAGE", + "destination", f1.Header.Get("destination"), + "message-id", "1", + "subscription", f1.Header.Get("id")) + messageFrame.Body = []byte("Message body") + err = rw.Write(messageFrame) + c.Assert(err, IsNil) + }() + + sub, err := conn.Subscribe("/queue/test1", AckAuto) + c.Assert(err, IsNil) + c.Check(conn.readTimeout, Equals, 101*time.Millisecond) + //println("read timeout", conn.readTimeout.String()) + + msg, ok := <-sub.C + c.Assert(msg, NotNil) + c.Assert(ok, Equals, true) + + msg, ok = <-sub.C + c.Assert(msg, NotNil) + c.Assert(ok, Equals, true) + c.Assert(msg.Err, NotNil) + c.Assert(msg.Err.Error(), Equals, "read timeout") + + msg, ok = <-sub.C + c.Assert(msg, IsNil) + c.Assert(ok, Equals, false) +} + +func (s *StompSuite) TestHeartBeatWriteTimeout(c *C) { + c.Skip("not finished yet") + conn, rw := createHeartBeatConnection(c, 10000, 100, time.Millisecond*1) + + go func() { + f1, err := rw.Read() + c.Assert(err, IsNil) + c.Assert(f1, IsNil) + + }() + + time.Sleep(250) + err := conn.Disconnect() + c.Assert(err, IsNil) +} + +func createHeartBeatConnection( + c *C, + readTimeout, writeTimeout int, + readTimeoutError time.Duration) (*Conn, *fakeReaderWriter) { + fc1, fc2 := testutil.NewFakeConn(c) + stop := make(chan struct{}) + + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + + go func() { + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + c.Assert(f1.Header.Get("heart-beat"), Equals, "1,1") + f2 := frame.New("CONNECTED", "version", "1.2") + f2.Header.Add("heart-beat", fmt.Sprintf("%d,%d", readTimeout, writeTimeout)) + err = writer.Write(f2) + c.Assert(err, IsNil) + close(stop) + }() + + conn, err := Connect(fc1, + ConnOpt.HeartBeat(time.Millisecond, time.Millisecond), + ConnOpt.HeartBeatError(readTimeoutError)) + c.Assert(conn, NotNil) + c.Assert(err, IsNil) + <-stop + return conn, &fakeReaderWriter{ + reader: reader, + writer: writer, + conn: fc2, + } +} + +// Testing Timeouts when receiving receipts +func sendFrameHelper(f *frame.Frame, c chan *frame.Frame) { + c <- f +} + +//// GIVEN_TheTimeoutIsExceededBeforeTheReceiptIsReceived_WHEN_CallingReadReceiptWithTimeout_THEN_ReturnAnError +func (s *StompSuite) Test_TimeoutTriggers(c *C) { + const timeout = 1 * time.Millisecond + f := frame.Frame{} + request := writeRequest{ + Frame: &f, + C: make(chan *frame.Frame), + } + + err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout) + + c.Assert(err, NotNil) +} + +//// GIVEN_TheChannelReceivesTheReceiptBeforeTheTimeoutExpires_WHEN_CallingReadReceiptWithTimeout_THEN_DoNotReturnAnError +func (s *StompSuite) Test_ChannelReceviesReceipt(c *C) { + const timeout = 1 * time.Second + f := frame.Frame{} + request := writeRequest{ + Frame: &f, + C: make(chan *frame.Frame), + } + receipt := frame.Frame{ + Command: frame.RECEIPT, + } + + go sendFrameHelper(&receipt, request.C) + err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout) + + c.Assert(err, IsNil) +} + +//// GIVEN_TheChannelReceivesMessage_AND_TheMessageIsNotAReceipt_WHEN_CallingReadReceiptWithTimeout_THEN_ReturnAnError +func (s *StompSuite) Test_ChannelReceviesNonReceipt(c *C) { + const timeout = 1 * time.Second + f := frame.Frame{} + request := writeRequest{ + Frame: &f, + C: make(chan *frame.Frame), + } + receipt := frame.Frame{ + Command: "NOT A RECEIPT", + } + + go sendFrameHelper(&receipt, request.C) + err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout) + + c.Assert(err, NotNil) +} + +//// GIVEN_TheTimeoutIsSetToZero_AND_TheMessageIsReceived_WHEN_CallingReadReceiptWithTimeout_THEN_DoNotReturnAnError +func (s *StompSuite) Test_ZeroTimeout(c *C) { + const timeout = 0 * time.Second + f := frame.Frame{} + request := writeRequest{ + Frame: &f, + C: make(chan *frame.Frame), + } + receipt := frame.Frame{ + Command: frame.RECEIPT, + } + + go sendFrameHelper(&receipt, request.C) + err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout) + + c.Assert(err, IsNil) +} diff --git a/backend/services/stomp/errors.go b/backend/services/stomp/errors.go new file mode 100644 index 0000000..36bee38 --- /dev/null +++ b/backend/services/stomp/errors.go @@ -0,0 +1,57 @@ +package stomp + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// Error values +var ( + ErrInvalidCommand = newErrorMessage("invalid command") + ErrInvalidFrameFormat = newErrorMessage("invalid frame format") + ErrUnsupportedVersion = newErrorMessage("unsupported version") + ErrCompletedTransaction = newErrorMessage("transaction is completed") + ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0") + ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server") + ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto") + ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed") + ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly") + ErrAlreadyClosed = newErrorMessage("connection already closed") + ErrMsgSendTimeout = newErrorMessage("msg send timeout") + ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout") + ErrDisconnectReceiptTimeout = newErrorMessage("disconnect receipt timeout") + ErrNilOption = newErrorMessage("nil option") +) + +// StompError implements the Error interface, and provides +// additional information about a STOMP error. +type Error struct { + Message string + Frame *frame.Frame +} + +func (e Error) Error() string { + return e.Message +} + +func missingHeader(name string) Error { + return newErrorMessage("missing header: " + name) +} + +func newErrorMessage(msg string) Error { + return Error{Message: msg} +} + +func newError(f *frame.Frame) Error { + e := Error{Frame: f} + + if f.Command == frame.ERROR { + if message := f.Header.Get(frame.Message); message != "" { + e.Message = message + } else { + e.Message = "ERROR frame, missing message header" + } + } else { + e.Message = "Unexpected frame: " + f.Command + } + return e +} diff --git a/backend/services/stomp/example_test.go b/backend/services/stomp/example_test.go new file mode 100644 index 0000000..4f095d9 --- /dev/null +++ b/backend/services/stomp/example_test.go @@ -0,0 +1,242 @@ +package stomp_test + +import ( + "fmt" + "net" + "time" + + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/frame" +) + +func ExampleConn_Send(c *stomp.Conn) error { + // send with receipt and an optional header + err := c.Send( + "/queue/test-1", // destination + "text/plain", // content-type + []byte("Message number 1"), // body + stomp.SendOpt.Receipt, + stomp.SendOpt.Header("expires", "2049-12-31 23:59:59")) + if err != nil { + return err + } + + // send with no receipt and no optional headers + err = c.Send("/queue/test-2", "application/xml", + []byte("hello")) + if err != nil { + return err + } + + return nil +} + +// Creates a new Header. +func ExampleNewHeader() { + /* + Creates a header that looks like the following: + + login:scott + passcode:tiger + host:stompserver + accept-version:1.1,1.2 + */ + h := frame.NewHeader( + "login", "scott", + "passcode", "tiger", + "host", "stompserver", + "accept-version", "1.1,1.2") + doSomethingWith(h) +} + +// Creates a STOMP frame. +func ExampleNewFrame() { + /* + Creates a STOMP frame that looks like the following: + + CONNECT + login:scott + passcode:tiger + host:stompserver + accept-version:1.1,1.2 + + ^@ + */ + f := frame.New("CONNECT", + "login", "scott", + "passcode", "tiger", + "host", "stompserver", + "accept-version", "1.1,1.2") + doSomethingWith(f) +} + +func doSomethingWith(f ...interface{}) { + +} + +func doAnotherThingWith(f interface{}, g interface{}) { + +} + +func ExampleConn_Subscribe_1() error { + conn, err := stomp.Dial("tcp", "localhost:61613") + if err != nil { + return err + } + + sub, err := conn.Subscribe("/queue/test-2", stomp.AckClient) + if err != nil { + return err + } + + // receive 5 messages and then quit + for i := 0; i < 5; i++ { + msg := <-sub.C + if msg.Err != nil { + return msg.Err + } + + doSomethingWith(msg) + + // acknowledge the message + err = conn.Ack(msg) + if err != nil { + return err + } + } + + err = sub.Unsubscribe() + if err != nil { + return err + } + + return conn.Disconnect() +} + +// Example of creating subscriptions with various options. +func ExampleConn_Subscribe_2(c *stomp.Conn) error { + // Subscribe to queue with automatic acknowledgement + sub1, err := c.Subscribe("/queue/test-1", stomp.AckAuto) + if err != nil { + return err + } + + // Subscribe to queue with client acknowledgement and a custom header value + sub2, err := c.Subscribe("/queue/test-2", stomp.AckClient, + stomp.SubscribeOpt.Header("x-custom-header", "some-value")) + if err != nil { + return err + } + + doSomethingWith(sub1, sub2) + + return nil +} + +func ExampleTransaction() error { + conn, err := stomp.Dial("tcp", "localhost:61613") + if err != nil { + return err + } + defer conn.Disconnect() + + sub, err := conn.Subscribe("/queue/test-2", stomp.AckClient) + if err != nil { + return err + } + + // receive 5 messages and then quit + for i := 0; i < 5; i++ { + msg := <-sub.C + if msg.Err != nil { + return msg.Err + } + + tx := conn.Begin() + + doAnotherThingWith(msg, tx) + + tx.Send("/queue/another-one", "text/plain", + []byte(fmt.Sprintf("Message #%d", i)), nil) + + // acknowledge the message + err = tx.Ack(msg) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + } + + err = sub.Unsubscribe() + if err != nil { + return err + } + + return nil +} + +// Example of connecting to a STOMP server using an existing network connection. +func ExampleConnect() error { + netConn, err := net.DialTimeout("tcp", "stomp.server.com:61613", 10*time.Second) + if err != nil { + return err + } + + stompConn, err := stomp.Connect(netConn) + if err != nil { + return err + } + + defer stompConn.Disconnect() + + doSomethingWith(stompConn) + return nil +} + +// Connect to a STOMP server using default options. +func ExampleDial_1() error { + conn, err := stomp.Dial("tcp", "192.168.1.1:61613") + if err != nil { + return err + } + + err = conn.Send( + "/queue/test-1", // destination + "text/plain", // content-type + []byte("Test message #1")) // body + if err != nil { + return err + } + + return conn.Disconnect() +} + +// Connect to a STOMP server that requires authentication. In addition, +// we are only prepared to use STOMP protocol version 1.1 or 1.2, and +// the virtual host is named "dragon". In this example the STOMP +// server also accepts a non-standard header called 'nonce'. +func ExampleDial_2() error { + conn, err := stomp.Dial("tcp", "192.168.1.1:61613", + stomp.ConnOpt.Login("scott", "leopard"), + stomp.ConnOpt.AcceptVersion(stomp.V11), + stomp.ConnOpt.AcceptVersion(stomp.V12), + stomp.ConnOpt.Host("dragon"), + stomp.ConnOpt.Header("nonce", "B256B26D320A")) + if err != nil { + return err + } + + err = conn.Send( + "/queue/test-1", // destination + "text/plain", // content-type + []byte("Test message #1")) // body + if err != nil { + return err + } + + return conn.Disconnect() +} diff --git a/backend/services/stomp/examples/client_test/main.go b/backend/services/stomp/examples/client_test/main.go new file mode 100644 index 0000000..7fd61c8 --- /dev/null +++ b/backend/services/stomp/examples/client_test/main.go @@ -0,0 +1,98 @@ +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/go-stomp/stomp/v3" +) + +const defaultPort = ":61613" + +var serverAddr = flag.String("server", "localhost:61613", "STOMP server endpoint") +var messageCount = flag.Int("count", 10, "Number of messages to send/receive") +var queueName = flag.String("queue", "/queue/client_test", "Destination queue") +var helpFlag = flag.Bool("help", false, "Print help text") +var stop = make(chan bool) + +// these are the default options that work with RabbitMQ +var options []func(*stomp.Conn) error = []func(*stomp.Conn) error{ + stomp.ConnOpt.Login("guest", "guest"), + stomp.ConnOpt.Host("/"), +} + +func main() { + flag.Parse() + if *helpFlag { + fmt.Fprintf(os.Stderr, "Usage of %s\n", os.Args[0]) + flag.PrintDefaults() + os.Exit(1) + } + + subscribed := make(chan bool) + go recvMessages(subscribed) + + // wait until we know the receiver has subscribed + <-subscribed + + go sendMessages() + + <-stop + <-stop +} + +func sendMessages() { + defer func() { + stop <- true + }() + + conn, err := stomp.Dial("tcp", *serverAddr, options...) + if err != nil { + println("cannot connect to server", err.Error()) + return + } + + for i := 1; i <= *messageCount; i++ { + text := fmt.Sprintf("Message #%d", i) + err = conn.Send(*queueName, "text/plain", + []byte(text), nil) + if err != nil { + println("failed to send to server", err) + return + } + } + println("sender finished") +} + +func recvMessages(subscribed chan bool) { + defer func() { + stop <- true + }() + + conn, err := stomp.Dial("tcp", *serverAddr, options...) + + if err != nil { + println("cannot connect to server", err.Error()) + return + } + + sub, err := conn.Subscribe(*queueName, stomp.AckAuto) + if err != nil { + println("cannot subscribe to", *queueName, err.Error()) + return + } + close(subscribed) + + for i := 1; i <= *messageCount; i++ { + msg := <-sub.C + expectedText := fmt.Sprintf("Message #%d", i) + actualText := string(msg.Body) + if expectedText != actualText { + println("Expected:", expectedText) + println("Actual:", actualText) + } + } + println("receiver finished") + +} diff --git a/backend/services/stomp/frame/ack.go b/backend/services/stomp/frame/ack.go new file mode 100644 index 0000000..f36de4a --- /dev/null +++ b/backend/services/stomp/frame/ack.go @@ -0,0 +1,8 @@ +package frame + +// Valid values for the "ack" header entry. +const ( + AckAuto = "auto" // Client does not send ACK + AckClient = "client" // Client sends ACK/NACK + AckClientIndividual = "client-individual" // Client sends ACK/NACK for individual messages +) diff --git a/backend/services/stomp/frame/command.go b/backend/services/stomp/frame/command.go new file mode 100644 index 0000000..f1ef76f --- /dev/null +++ b/backend/services/stomp/frame/command.go @@ -0,0 +1,26 @@ +package frame + +// STOMP frame commands. Used upper case naming +// convention to avoid clashing with STOMP header names. +const ( + // Connect commands. + CONNECT = "CONNECT" + STOMP = "STOMP" + CONNECTED = "CONNECTED" + + // Client commands. + SEND = "SEND" + SUBSCRIBE = "SUBSCRIBE" + UNSUBSCRIBE = "UNSUBSCRIBE" + ACK = "ACK" + NACK = "NACK" + BEGIN = "BEGIN" + COMMIT = "COMMIT" + ABORT = "ABORT" + DISCONNECT = "DISCONNECT" + + // Server commands. + MESSAGE = "MESSAGE" + RECEIPT = "RECEIPT" + ERROR = "ERROR" +) diff --git a/backend/services/stomp/frame/encode.go b/backend/services/stomp/frame/encode.go new file mode 100644 index 0000000..ecd187c --- /dev/null +++ b/backend/services/stomp/frame/encode.go @@ -0,0 +1,34 @@ +package frame + +import ( + "strings" + "unsafe" +) + +var ( + replacerForEncodeValue = strings.NewReplacer( + "\\", "\\\\", + "\r", "\\r", + "\n", "\\n", + ":", "\\c", + ) + replacerForUnencodeValue = strings.NewReplacer( + "\\r", "\r", + "\\n", "\n", + "\\c", ":", + "\\\\", "\\", + ) +) + +// Reduce one allocation on copying bytes to string +func bytesToString(b []byte) string { + /* #nosec G103 */ + return *(*string)(unsafe.Pointer(&b)) +} + +// Unencodes a header value using STOMP value encoding +// TODO: return error if invalid sequences found (eg "\t") +func unencodeValue(b []byte) (string, error) { + s := replacerForUnencodeValue.Replace(bytesToString(b)) + return s, nil +} diff --git a/backend/services/stomp/frame/encode_test.go b/backend/services/stomp/frame/encode_test.go new file mode 100644 index 0000000..99a1ab9 --- /dev/null +++ b/backend/services/stomp/frame/encode_test.go @@ -0,0 +1,15 @@ +package frame + +import ( + . "gopkg.in/check.v1" +) + +type EncodeSuite struct{} + +var _ = Suite(&EncodeSuite{}) + +func (s *EncodeSuite) TestUnencodeValue(c *C) { + val, err := unencodeValue([]byte(`Contains\r\nNewLine and \c colon and \\ backslash`)) + c.Check(err, IsNil) + c.Check(val, Equals, "Contains\r\nNewLine and : colon and \\ backslash") +} diff --git a/backend/services/stomp/frame/errors.go b/backend/services/stomp/frame/errors.go new file mode 100644 index 0000000..672b918 --- /dev/null +++ b/backend/services/stomp/frame/errors.go @@ -0,0 +1,9 @@ +package frame + +import ( + "errors" +) + +var ( + ErrInvalidHeartBeat = errors.New("invalid heart-beat") +) diff --git a/backend/services/stomp/frame/frame.go b/backend/services/stomp/frame/frame.go new file mode 100644 index 0000000..a94b5f1 --- /dev/null +++ b/backend/services/stomp/frame/frame.go @@ -0,0 +1,38 @@ +/* +Package frame provides functionality for manipulating STOMP frames. +*/ +package frame + +// A Frame represents a STOMP frame. A frame consists of a command +// followed by a collection of header entries, and then an optional +// body. +type Frame struct { + Command string + Header *Header + Body []byte +} + +// New creates a new STOMP frame with the specified command and headers. +// The headers should contain an even number of entries. Each even index is +// the header name, and the odd indexes are the assocated header values. +func New(command string, headers ...string) *Frame { + f := &Frame{Command: command, Header: &Header{}} + for index := 0; index < len(headers); index += 2 { + f.Header.Add(headers[index], headers[index+1]) + } + return f +} + +// Clone creates a deep copy of the frame and its header. The cloned +// frame shares the body with the original frame. +func (f *Frame) Clone() *Frame { + fc := &Frame{Command: f.Command} + if f.Header != nil { + fc.Header = f.Header.Clone() + } + if f.Body != nil { + fc.Body = make([]byte, len(f.Body)) + copy(fc.Body, f.Body) + } + return fc +} diff --git a/backend/services/stomp/frame/frame_test.go b/backend/services/stomp/frame/frame_test.go new file mode 100644 index 0000000..1f31a1e --- /dev/null +++ b/backend/services/stomp/frame/frame_test.go @@ -0,0 +1,67 @@ +package frame + +import ( + "testing" + + . "gopkg.in/check.v1" +) + +func TestFrame(t *testing.T) { + TestingT(t) +} + +type FrameSuite struct{} + +var _ = Suite(&FrameSuite{}) + +func (s *FrameSuite) TestNew(c *C) { + f := New("CCC") + c.Check(f.Header.Len(), Equals, 0) + c.Check(f.Command, Equals, "CCC") + + f = New("DDDD", "abc", "def") + c.Check(f.Header.Len(), Equals, 1) + k, v := f.Header.GetAt(0) + c.Check(k, Equals, "abc") + c.Check(v, Equals, "def") + c.Check(f.Command, Equals, "DDDD") + + f = New("EEEEEEE", "abc", "def", "hij", "klm") + c.Check(f.Command, Equals, "EEEEEEE") + c.Check(f.Header.Len(), Equals, 2) + k, v = f.Header.GetAt(0) + c.Check(k, Equals, "abc") + c.Check(v, Equals, "def") + k, v = f.Header.GetAt(1) + c.Check(k, Equals, "hij") + c.Check(v, Equals, "klm") +} + +func (s *FrameSuite) TestClone(c *C) { + f1 := &Frame{ + Command: "AAAA", + } + + f2 := f1.Clone() + c.Check(f2.Command, Equals, f1.Command) + c.Check(f2.Header, IsNil) + c.Check(f2.Body, IsNil) + + f1.Header = NewHeader("aaa", "1", "bbb", "2", "ccc", "3") + + f2 = f1.Clone() + c.Check(f2.Header.Len(), Equals, f1.Header.Len()) + for i := 0; i < f1.Header.Len(); i++ { + k1, v1 := f1.Header.GetAt(i) + k2, v2 := f2.Header.GetAt(i) + c.Check(k1, Equals, k2) + c.Check(v1, Equals, v2) + } + + f1.Body = []byte{1, 2, 3, 4, 5, 6, 5, 4, 77, 88, 99, 0xaa, 0xbb, 0xcc, 0xff} + f2 = f1.Clone() + c.Check(len(f2.Body), Equals, len(f1.Body)) + for i := 0; i < len(f1.Body); i++ { + c.Check(f1.Body[i], Equals, f2.Body[i]) + } +} diff --git a/backend/services/stomp/frame/header.go b/backend/services/stomp/frame/header.go new file mode 100644 index 0000000..6eb60e2 --- /dev/null +++ b/backend/services/stomp/frame/header.go @@ -0,0 +1,192 @@ +package frame + +import ( + "strconv" +) + +// STOMP header names. Some of the header +// names have commands with the same name +// (eg Ack, Message, Receipt). Commands use +// an upper-case naming convention, header +// names use pascal-case naming convention. +const ( + ContentLength = "content-length" + ContentType = "content-type" + Receipt = "receipt" + AcceptVersion = "accept-version" + Host = "host" + Version = "version" + Login = "login" + Passcode = "passcode" + HeartBeat = "heart-beat" + Session = "session" + Server = "server" + Destination = "destination" + Id = "id" + Ack = "ack" + Transaction = "transaction" + ReceiptId = "receipt-id" + Subscription = "subscription" + MessageId = "message-id" + Message = "message" + /* TR-369 section 4.4.2.1 [Subscribing a USP Endpoint to a STOMP Destination] */ + /* + R-STOMP.14: USP Agents that receive a subscribe-dest STOMP Header in the CONNECTED + frame MUST use that STOMP destination in the destination STOMP header when sending a + SUBSCRIBE frame. + */ + SubscribeDest = "subscribe-dest" +) + +// A Header represents the header part of a STOMP frame. +// The header in a STOMP frame consists of a list of header entries. +// Each header entry is a key/value pair of strings. +// +// Normally a STOMP header only has one header entry for a given key, but +// the STOMP standard does allow for multiple header entries with the same +// key. In this case, the first header entry contains the value, and any +// subsequent header entries with the same key are ignored. +// +// Example header containing 6 header entries. Note that the second +// header entry with the key "comment" would be ignored. +// +// login:scott +// passcode:tiger +// host:stompserver +// accept-version:1.0,1.1,1.2 +// comment:some comment +// comment:another comment +type Header struct { + slice []string +} + +// NewHeader creates a new Header and populates it with header entries. +// This function expects an even number of strings as parameters. The +// even numbered indices are keys and the odd indices are values. See +// the example for more information. +func NewHeader(headerEntries ...string) *Header { + h := &Header{} + h.slice = append(h.slice, headerEntries...) + if len(h.slice)%2 != 0 { + h.slice = append(h.slice, "") + } + return h +} + +// Add adds the key, value pair to the header. +func (h *Header) Add(key, value string) { + h.slice = append(h.slice, key, value) +} + +// AddHeader adds all of the key value pairs in header to h. +func (h *Header) AddHeader(header *Header) { + if header != nil { + for i := 0; i < header.Len(); i++ { + key, value := header.GetAt(i) + h.Add(key, value) + } + } +} + +// Set replaces the value of any existing header entry with the specified key. +// If there is no existing header entry with the specified key, a new +// header entry is added. +func (h *Header) Set(key, value string) { + if i, ok := h.index(key); ok { + h.slice[i+1] = value + } else { + h.slice = append(h.slice, key, value) + } +} + +// Get gets the first value associated with the given key. +// If there are no values associated with the key, Get returns "". +func (h *Header) Get(key string) string { + value, _ := h.Contains(key) + return value +} + +// GetAll returns all of the values associated with a given key. +// Normally there is only one header entry per key, but it is permitted +// to have multiple entries according to the STOMP standard. +func (h *Header) GetAll(key string) []string { + var values []string + for i := 0; i < len(h.slice); i += 2 { + if h.slice[i] == key { + values = append(values, h.slice[i+1]) + } + } + return values +} + +// Returns the header name and value at the specified index in +// the collection. The index should be in the range 0 <= index < Len(), +// a panic will occur if it is outside this range. +func (h *Header) GetAt(index int) (key, value string) { + index *= 2 + return h.slice[index], h.slice[index+1] +} + +// Contains gets the first value associated with the given key, +// and also returns a bool indicating whether the header entry +// exists. +// +// If there are no values associated with the key, Get returns "" +// for the value, and ok is false. +func (h *Header) Contains(key string) (value string, ok bool) { + var i int + if i, ok = h.index(key); ok { + value = h.slice[i+1] + } + return +} + +// Del deletes all header entries with the specified key. +func (h *Header) Del(key string) { + for i, ok := h.index(key); ok; i, ok = h.index(key) { + h.slice = append(h.slice[:i], h.slice[i+2:]...) + } +} + +// Len returns the number of header entries in the header. +func (h *Header) Len() int { + return len(h.slice) / 2 +} + +// Clone returns a deep copy of a Header. +func (h *Header) Clone() *Header { + hc := &Header{slice: make([]string, len(h.slice))} + copy(hc.slice, h.slice) + return hc +} + +// ContentLength returns the value of the "content-length" header entry. +// If the "content-length" header is missing, then ok is false. If the +// "content-length" entry is present but is not a valid non-negative integer +// then err is non-nil. +func (h *Header) ContentLength() (value int, ok bool, err error) { + text, ok := h.Contains(ContentLength) + if !ok { + return 0, false, nil + } + + n, err := strconv.ParseUint(text, 10, 32) + if err != nil { + return 0, true, err + } + + value = int(n) + ok = true + return value, ok, nil +} + +// Returns the index of a header key in Headers, and a bool to indicate +// whether it was found or not. +func (h *Header) index(key string) (int, bool) { + for i := 0; i < len(h.slice); i += 2 { + if h.slice[i] == key { + return i, true + } + } + return -1, false +} diff --git a/backend/services/stomp/frame/header_test.go b/backend/services/stomp/frame/header_test.go new file mode 100644 index 0000000..11330db --- /dev/null +++ b/backend/services/stomp/frame/header_test.go @@ -0,0 +1,69 @@ +package frame + +import ( + . "gopkg.in/check.v1" +) + +func (s *FrameSuite) TestHeaderGetSetAddDel(c *C) { + h := &Header{} + c.Assert(h.Get("xxx"), Equals, "") + h.Add("xxx", "yyy") + c.Assert(h.Get("xxx"), Equals, "yyy") + h.Add("xxx", "zzz") + c.Assert(h.GetAll("xxx"), DeepEquals, []string{"yyy", "zzz"}) + h.Set("xxx", "111") + c.Assert(h.Get("xxx"), Equals, "111") + h.Del("xxx") + c.Assert(h.Get("xxx"), Equals, "") +} + +func (s *FrameSuite) TestHeaderClone(c *C) { + h := Header{} + h.Set("xxx", "yyy") + h.Set("yyy", "zzz") + + hc := h.Clone() + h.Del("xxx") + h.Del("yyy") + c.Assert(hc.Get("xxx"), Equals, "yyy") + c.Assert(hc.Get("yyy"), Equals, "zzz") +} + +func (s *FrameSuite) TestHeaderContains(c *C) { + h := NewHeader("xxx", "yyy", "zzz", "aaa", "xxx", "ccc") + v, ok := h.Contains("xxx") + c.Assert(v, Equals, "yyy") + c.Assert(ok, Equals, true) + + v, ok = h.Contains("123") + c.Assert(v, Equals, "") + c.Assert(ok, Equals, false) +} + +func (s *FrameSuite) TestContentLength(c *C) { + h := NewHeader("xxx", "yy", "content-length", "202", "zz", "123") + cl, ok, err := h.ContentLength() + c.Assert(cl, Equals, 202) + c.Assert(ok, Equals, true) + c.Assert(err, Equals, nil) + + h.Set("content-length", "twenty") + cl, ok, err = h.ContentLength() + c.Assert(cl, Equals, 0) + c.Assert(ok, Equals, true) + c.Assert(err, NotNil) + + h.Del("content-length") + cl, ok, err = h.ContentLength() + c.Assert(cl, Equals, 0) + c.Assert(ok, Equals, false) + c.Assert(err, IsNil) +} + +func (s *FrameSuite) TestLit(c *C) { + _ = Frame{ + Command: "CONNECT", + Header: NewHeader("login", "xxx", "passcode", "yyy"), + Body: []byte{1, 2, 3, 4}, + } +} diff --git a/backend/services/stomp/frame/heartbeat.go b/backend/services/stomp/frame/heartbeat.go new file mode 100644 index 0000000..2f0eb86 --- /dev/null +++ b/backend/services/stomp/frame/heartbeat.go @@ -0,0 +1,44 @@ +package frame + +import ( + "math" + "regexp" + "strconv" + "strings" + "time" +) + +var ( + // Regexp for heart-beat header value + heartBeatRegexp = regexp.MustCompile("^[0-9]+,[0-9]+$") +) + +const ( + // Maximum number of milliseconds that can be represented + // in a time.Duration. + maxMilliseconds = math.MaxInt64 / int64(time.Millisecond) +) + +// ParseHeartBeat parses the value of a STOMP heart-beat entry and +// returns two time durations. Returns an error if the heart-beat +// value is not in the correct format, or if the time durations are +// too big to be represented by the time.Duration type. +func ParseHeartBeat(heartBeat string) (time.Duration, time.Duration, error) { + if !heartBeatRegexp.MatchString(heartBeat) { + return 0, 0, ErrInvalidHeartBeat + } + slice := strings.Split(heartBeat, ",") + value1, err := strconv.ParseInt(slice[0], 10, 64) + if err != nil { + return 0, 0, ErrInvalidHeartBeat + } + value2, err := strconv.ParseInt(slice[1], 10, 64) + if err != nil { + return 0, 0, ErrInvalidHeartBeat + } + if value1 > maxMilliseconds || value2 > maxMilliseconds { + return 0, 0, ErrInvalidHeartBeat + } + return time.Duration(value1) * time.Millisecond, + time.Duration(value2) * time.Millisecond, nil +} diff --git a/backend/services/stomp/frame/heartbeat_test.go b/backend/services/stomp/frame/heartbeat_test.go new file mode 100644 index 0000000..6dff501 --- /dev/null +++ b/backend/services/stomp/frame/heartbeat_test.go @@ -0,0 +1,77 @@ +package frame + +import ( + "time" + + . "gopkg.in/check.v1" +) + +func (s *FrameSuite) TestParseHeartBeat(c *C) { + testCases := []struct { + Input string + ExpectedDuration1 time.Duration + ExpectedDuration2 time.Duration + ExpectError bool + ExpectedError error + }{ + { + Input: "0,0", + ExpectedDuration1: 0, + ExpectedDuration2: 0, + }, + { + Input: "20000,60000", + ExpectedDuration1: 20 * time.Second, + ExpectedDuration2: time.Minute, + }, + { + Input: "86400000,31536000000", + ExpectedDuration1: 24 * time.Hour, + ExpectedDuration2: 365 * 24 * time.Hour, + }, + { + Input: "20r000,60000", + ExpectedDuration1: 0, + ExpectedDuration2: 0, + ExpectedError: ErrInvalidHeartBeat, + }, + { + Input: "99999999999999999999,60000", + ExpectedDuration1: 0, + ExpectedDuration2: 0, + ExpectedError: ErrInvalidHeartBeat, + }, + { + Input: "60000,99999999999999999999", + ExpectedDuration1: 0, + ExpectedDuration2: 0, + ExpectedError: ErrInvalidHeartBeat, + }, + { + Input: "-60000,60000", + ExpectedDuration1: 0, + ExpectedDuration2: 0, + ExpectedError: ErrInvalidHeartBeat, + }, + { + Input: "60000,-60000", + ExpectedDuration1: 0, + ExpectedDuration2: 0, + ExpectedError: ErrInvalidHeartBeat, + }, + } + + for _, tc := range testCases { + d1, d2, err := ParseHeartBeat(tc.Input) + c.Check(d1, Equals, tc.ExpectedDuration1) + c.Check(d2, Equals, tc.ExpectedDuration2) + if tc.ExpectError || tc.ExpectedError != nil { + c.Check(err, NotNil) + if tc.ExpectedError != nil { + c.Check(err, Equals, tc.ExpectedError) + } + } else { + c.Check(err, IsNil) + } + } +} diff --git a/backend/services/stomp/frame/reader.go b/backend/services/stomp/frame/reader.go new file mode 100644 index 0000000..ded3dfe --- /dev/null +++ b/backend/services/stomp/frame/reader.go @@ -0,0 +1,157 @@ +package frame + +import ( + "bufio" + "bytes" + "errors" + "io" +) + +const ( + bufferSize = 4096 + newline = byte(10) + cr = byte(13) + colon = byte(58) + nullByte = byte(0) +) + +var ( + ErrInvalidCommand = errors.New("invalid command") + ErrInvalidFrameFormat = errors.New("invalid frame format") +) + +// The Reader type reads STOMP frames from an underlying io.Reader. +// The reader is buffered, and the size of the buffer is the maximum +// size permitted for the STOMP frame command and header section. +// A STOMP frame is rejected if its command and header section exceed +// the buffer size. +type Reader struct { + reader *bufio.Reader +} + +// NewReader creates a Reader with the default underlying buffer size. +func NewReader(reader io.Reader) *Reader { + return NewReaderSize(reader, bufferSize) +} + +// NewReaderSize creates a Reader with an underlying bufferSize +// of the specified size. +func NewReaderSize(reader io.Reader, bufferSize int) *Reader { + return &Reader{reader: bufio.NewReaderSize(reader, bufferSize)} +} + +// Read a STOMP frame from the input. If the input contains one +// or more heart-beat characters and no frame, then nil will +// be returned for the frame. Calling programs should always check +// for a nil frame. +func (r *Reader) Read() (*Frame, error) { + commandSlice, err := r.readLine() + if err != nil { + return nil, err + } + + if len(commandSlice) == 0 { + // received a heart-beat newline char (or cr-lf) + return nil, nil + } + + f := New(string(commandSlice)) + //println("RX:", f.Command) + switch f.Command { + // TODO(jpj): Is it appropriate to perform validation on the + // command at this point. Probably better to validate higher up, + // this way this type can be useful for any other non-STOMP protocols + // which happen to use the same frame format. + case CONNECT, STOMP, SEND, SUBSCRIBE, + UNSUBSCRIBE, ACK, NACK, BEGIN, + COMMIT, ABORT, DISCONNECT, CONNECTED, + MESSAGE, RECEIPT, ERROR: + // valid command + default: + return nil, ErrInvalidCommand + } + + // read headers + for { + headerSlice, err := r.readLine() + if err != nil { + return nil, err + } + + if len(headerSlice) == 0 { + // empty line means end of headers + break + } + + index := bytes.IndexByte(headerSlice, colon) + if index <= 0 { + // colon is missing or header name is zero length + return nil, ErrInvalidFrameFormat + } + + name, err := unencodeValue(headerSlice[0:index]) + if err != nil { + return nil, err + } + value, err := unencodeValue(headerSlice[index+1:]) + if err != nil { + return nil, err + } + + //println(" ", name, ":", value) + + f.Header.Add(name, value) + } + + // get content length from the headers + if contentLength, ok, err := f.Header.ContentLength(); err != nil { + // happens if the content is malformed + return nil, err + } else if ok { + // content length specified in the header, so use that + f.Body = make([]byte, contentLength) + for bytesRead := 0; bytesRead < contentLength; { + n, err := r.reader.Read(f.Body[bytesRead:contentLength]) + if err != nil { + return nil, err + } + bytesRead += n + } + + // read the next byte and verify that it is a null byte + terminator, err := r.reader.ReadByte() + if err != nil { + return nil, err + } + if terminator != 0 { + return nil, ErrInvalidFrameFormat + } + } else { + f.Body, err = r.reader.ReadBytes(nullByte) + if err != nil { + return nil, err + } + // remove trailing null + f.Body = f.Body[0 : len(f.Body)-1] + } + + // pass back frame + return f, nil +} + +// read one line from input and strip off terminating LF or terminating CR-LF +func (r *Reader) readLine() (line []byte, err error) { + line, err = r.reader.ReadBytes(newline) + if err != nil { + return + } + + switch { + case bytes.HasSuffix(line, crlfSlice): + line = line[0 : len(line)-len(crlfSlice)] + case bytes.HasSuffix(line, newlineSlice): + line = line[0 : len(line)-len(newlineSlice)] + } + + return +} diff --git a/backend/services/stomp/frame/reader_test.go b/backend/services/stomp/frame/reader_test.go new file mode 100644 index 0000000..d3baabf --- /dev/null +++ b/backend/services/stomp/frame/reader_test.go @@ -0,0 +1,140 @@ +package frame + +import ( + "io" + "strings" + "testing/iotest" + + . "gopkg.in/check.v1" +) + +type ReaderSuite struct{} + +var _ = Suite(&ReaderSuite{}) + +func (s *ReaderSuite) TestConnect(c *C) { + reader := NewReader(strings.NewReader("CONNECT\nlogin:xxx\npasscode:yyy\n\n\x00")) + + frame, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(frame, NotNil) + c.Assert(len(frame.Body), Equals, 0) + + // ensure we are at the end of input + frame, err = reader.Read() + c.Assert(frame, IsNil) + c.Assert(err, Equals, io.EOF) +} + +func (s *ReaderSuite) TestMultipleReads(c *C) { + text := "SEND\ndestination:xxx\n\nPayload\x00\n" + + "SEND\ndestination:yyy\ncontent-length:12\n" + + "dodgy\\c\\n\\cheader:dodgy\\c\\n\\r\\nvalue\\ \\\n\n" + + "123456789AB\x00\x00" + + ioreaders := []io.Reader{ + strings.NewReader(text), + iotest.DataErrReader(strings.NewReader(text)), + iotest.HalfReader(strings.NewReader(text)), + iotest.OneByteReader(strings.NewReader(text)), + } + + for _, ioreader := range ioreaders { + // uncomment the following line to view the bytes being read + //ioreader = iotest.NewReadLogger("RX", ioreader) + reader := NewReader(ioreader) + frame, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(frame, NotNil) + c.Assert(frame.Command, Equals, "SEND") + c.Assert(frame.Header.Len(), Equals, 1) + v := frame.Header.Get("destination") + c.Assert(v, Equals, "xxx") + c.Assert(string(frame.Body), Equals, "Payload") + + // now read a heart-beat from the input + frame, err = reader.Read() + c.Assert(err, IsNil) + c.Assert(frame, IsNil) + + // this frame has content-length + frame, err = reader.Read() + c.Assert(err, IsNil) + c.Assert(frame, NotNil) + c.Assert(frame.Command, Equals, "SEND") + c.Assert(frame.Header.Len(), Equals, 3) + v = frame.Header.Get("destination") + c.Assert(v, Equals, "yyy") + n, ok, err := frame.Header.ContentLength() + c.Assert(n, Equals, 12) + c.Assert(ok, Equals, true) + c.Assert(err, IsNil) + k, v := frame.Header.GetAt(2) + c.Assert(k, Equals, "dodgy:\n:header") + c.Assert(v, Equals, "dodgy:\n\r\nvalue\\ \\") + c.Assert(string(frame.Body), Equals, "123456789AB\x00") + + // ensure we are at the end of input + frame, err = reader.Read() + c.Assert(frame, IsNil) + c.Assert(err, Equals, io.EOF) + } +} + +func (s *ReaderSuite) TestSendWithContentLength(c *C) { + reader := NewReader(strings.NewReader("SEND\ndestination:xxx\ncontent-length:5\n\n\x00\x01\x02\x03\x04\x00")) + + frame, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(frame, NotNil) + c.Assert(frame.Command, Equals, "SEND") + c.Assert(frame.Header.Len(), Equals, 2) + v := frame.Header.Get("destination") + c.Assert(v, Equals, "xxx") + c.Assert(frame.Body, DeepEquals, []byte{0x00, 0x01, 0x02, 0x03, 0x04}) + + // ensure we are at the end of input + frame, err = reader.Read() + c.Assert(frame, IsNil) + c.Assert(err, Equals, io.EOF) +} + +func (s *ReaderSuite) TestInvalidCommand(c *C) { + reader := NewReader(strings.NewReader("sEND\ndestination:xxx\ncontent-length:5\n\n\x00\x01\x02\x03\x04\x00")) + + frame, err := reader.Read() + c.Check(frame, IsNil) + c.Assert(err, NotNil) + c.Check(err.Error(), Equals, "invalid command") +} + +func (s *ReaderSuite) TestMissingNull(c *C) { + reader := NewReader(strings.NewReader("SEND\ndeestination:xxx\ncontent-length:5\n\n\x00\x01\x02\x03\x04\n")) + + f, err := reader.Read() + c.Check(f, IsNil) + c.Assert(err, NotNil) + c.Check(err.Error(), Equals, "invalid frame format") +} + +func (s *ReaderSuite) TestSubscribeWithoutId(c *C) { + c.Skip("TODO: implement validate") + + reader := NewReader(strings.NewReader("SUBSCRIBE\ndestination:xxx\nIId:7\n\n\x00")) + + frame, err := reader.Read() + c.Check(frame, IsNil) + c.Assert(err, NotNil) + c.Check(err.Error(), Equals, "missing header: id") +} + +func (s *ReaderSuite) TestUnsubscribeWithoutId(c *C) { + c.Skip("TODO: implement validate") + + reader := NewReader(strings.NewReader("UNSUBSCRIBE\nIId:7\n\n\x00")) + + frame, err := reader.Read() + c.Check(frame, IsNil) + c.Assert(err, NotNil) + c.Check(err.Error(), Equals, "missing header: id") +} diff --git a/backend/services/stomp/frame/writer.go b/backend/services/stomp/frame/writer.go new file mode 100644 index 0000000..7c6e83e --- /dev/null +++ b/backend/services/stomp/frame/writer.go @@ -0,0 +1,100 @@ +package frame + +import ( + "bufio" + "io" +) + +// slices used to write frames +var ( + colonSlice = []byte{58} // colon ':' + crlfSlice = []byte{13, 10} // CR-LF + newlineSlice = []byte{10} // newline (LF) + nullSlice = []byte{0} // null character +) + +// Writes STOMP frames to an underlying io.Writer. +type Writer struct { + writer *bufio.Writer +} + +// Creates a new Writer object, which writes to an underlying io.Writer. +func NewWriter(writer io.Writer) *Writer { + return NewWriterSize(writer, 4096) +} + +func NewWriterSize(writer io.Writer, bufferSize int) *Writer { + return &Writer{writer: bufio.NewWriterSize(writer, bufferSize)} +} + +// Write the contents of a frame to the underlying io.Writer. +func (w *Writer) Write(f *Frame) error { + var err error + + if f == nil { + // nil frame means send a heart-beat LF + _, err = w.writer.Write(newlineSlice) + if err != nil { + return err + } + } else { + _, err = w.writer.Write([]byte(f.Command)) + if err != nil { + return err + } + + _, err = w.writer.Write(newlineSlice) + if err != nil { + return err + } + + //println("TX:", f.Command) + if f.Header != nil { + for i := 0; i < f.Header.Len(); i++ { + key, value := f.Header.GetAt(i) + //println(" ", key, ":", value) + _, err = replacerForEncodeValue.WriteString(w.writer, key) + if err != nil { + return err + } + _, err = w.writer.Write(colonSlice) + if err != nil { + return err + } + _, err = replacerForEncodeValue.WriteString(w.writer, value) + if err != nil { + return err + } + _, err = w.writer.Write(newlineSlice) + if err != nil { + return err + } + } + } + + _, err = w.writer.Write(newlineSlice) + if err != nil { + return err + } + + if len(f.Body) > 0 { + _, err = w.writer.Write(f.Body) + if err != nil { + return err + } + } + + // write the final null (0) byte + _, err = w.writer.Write(nullSlice) + if err != nil { + return err + } + } + + err = w.writer.Flush() + if err != nil { + return err + } + + return nil +} diff --git a/backend/services/stomp/frame/writer_test.go b/backend/services/stomp/frame/writer_test.go new file mode 100644 index 0000000..ea5a1f5 --- /dev/null +++ b/backend/services/stomp/frame/writer_test.go @@ -0,0 +1,48 @@ +package frame + +import ( + "bytes" + "strings" + + . "gopkg.in/check.v1" +) + +type WriterSuite struct{} + +var _ = Suite(&WriterSuite{}) + +func (s *WriterSuite) TestWrites(c *C) { + var frameTexts = []string{ + "CONNECT\nlogin:xxx\npasscode:yyy\n\n\x00", + + "SEND\n" + + "destination:/queue/request\n" + + "tx:1\n" + + "content-length:5\n" + + "\n\x00\x01\x02\x03\x04\x00", + + "SEND\ndestination:x\n\nABCD\x00", + + "SEND\ndestination:x\ndodgy\\nheader\\c:abc\\n\\c\n\n123456\x00", + } + + for _, frameText := range frameTexts { + writeToBufferAndCheck(c, frameText) + } +} + +func writeToBufferAndCheck(c *C, frameText string) { + reader := NewReader(strings.NewReader(frameText)) + + frame, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(frame, NotNil) + + var b bytes.Buffer + var writer = NewWriter(&b) + err = writer.Write(frame) + c.Assert(err, IsNil) + newFrameText := b.String() + c.Check(newFrameText, Equals, frameText) + c.Check(b.String(), Equals, frameText) +} diff --git a/backend/services/stomp/go.mod b/backend/services/stomp/go.mod index 6c5dc46..13788ca 100644 --- a/backend/services/stomp/go.mod +++ b/backend/services/stomp/go.mod @@ -1,8 +1,9 @@ -module github.com/leandrofars/stomp +module github.com/go-stomp/stomp/v3 -go 1.21.3 +go 1.15 require ( - github.com/go-stomp/stomp v2.1.4+incompatible // indirect + github.com/golang/mock v1.6.0 github.com/joho/godotenv v1.5.1 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c ) diff --git a/backend/services/stomp/go.sum b/backend/services/stomp/go.sum index 0df3975..5d3b5df 100644 --- a/backend/services/stomp/go.sum +++ b/backend/services/stomp/go.sum @@ -1,4 +1,34 @@ -github.com/go-stomp/stomp v2.1.4+incompatible h1:D3SheUVDOz9RsjVWkoh/1iCOwD0qWjyeTZMUZ0EXg2Y= -github.com/go-stomp/stomp v2.1.4+incompatible/go.mod h1:VqCtqNZv1226A1/79yh+rMiFUcfY3R109np+7ke4n0c= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/backend/services/stomp/id.go b/backend/services/stomp/id.go new file mode 100644 index 0000000..cd84748 --- /dev/null +++ b/backend/services/stomp/id.go @@ -0,0 +1,17 @@ +package stomp + +import ( + "strconv" + "sync/atomic" +) + +var _lastId uint64 + +// allocateId returns a unique number for the current +// process. Starts at one and increases. Used for +// allocating subscription ids, receipt ids, +// transaction ids, etc. +func allocateId() string { + id := atomic.AddUint64(&_lastId, 1) + return strconv.FormatUint(id, 10) +} diff --git a/backend/services/stomp/id_test.go b/backend/services/stomp/id_test.go new file mode 100644 index 0000000..074ce6b --- /dev/null +++ b/backend/services/stomp/id_test.go @@ -0,0 +1,43 @@ +package stomp + +import ( + . "gopkg.in/check.v1" + "runtime" +) + +// only used during testing, does not need to be thread-safe +func resetId() { + _lastId = 0 +} + +func (s *StompSuite) SetUpSuite(c *C) { + resetId() + runtime.GOMAXPROCS(runtime.NumCPU()) +} + +func (s *StompSuite) TearDownSuite(c *C) { + runtime.GOMAXPROCS(1) +} + +func (s *StompSuite) TestAllocateId(c *C) { + c.Assert(allocateId(), Equals, "1") + c.Assert(allocateId(), Equals, "2") + + ch := make(chan bool, 50) + for i := 0; i < 50; i++ { + go doAllocate(100, ch) + } + + for i := 0; i < 50; i++ { + <-ch + } + + c.Assert(allocateId(), Equals, "5003") +} + +func doAllocate(count int, ch chan bool) { + for i := 0; i < count; i++ { + _ = allocateId() + } + ch <- true +} diff --git a/backend/services/stomp/internal/log/stdlogger.go b/backend/services/stomp/internal/log/stdlogger.go new file mode 100644 index 0000000..784c6f0 --- /dev/null +++ b/backend/services/stomp/internal/log/stdlogger.go @@ -0,0 +1,51 @@ +package log + +import ( + "fmt" + stdlog "log" +) + +var ( + debugPrefix = "DEBUG: " + infoPrefix = "INFO: " + warnPrefix = "WARN: " + errorPrefix = "ERROR: " +) + +func logf(prefix string, format string, value ...interface{}) { + _ = stdlog.Output(3, fmt.Sprintf(prefix+format+"\n", value...)) +} + +type StdLogger struct{} + +func (s StdLogger) Debugf(format string, value ...interface{}) { + logf(debugPrefix, format, value...) +} + +func (s StdLogger) Debug(message string) { + logf(debugPrefix, "%s", message) +} + +func (s StdLogger) Infof(format string, value ...interface{}) { + logf(infoPrefix, format, value...) +} + +func (s StdLogger) Info(message string) { + logf(infoPrefix, "%s", message) +} + +func (s StdLogger) Warningf(format string, value ...interface{}) { + logf(warnPrefix, format, value...) +} + +func (s StdLogger) Warning(message string) { + logf(warnPrefix, "%s", message) +} + +func (s StdLogger) Errorf(format string, value ...interface{}) { + logf(errorPrefix, format, value...) +} + +func (s StdLogger) Error(message string) { + logf(errorPrefix, "%s", message) +} diff --git a/backend/services/stomp/logger.go b/backend/services/stomp/logger.go new file mode 100644 index 0000000..4651d46 --- /dev/null +++ b/backend/services/stomp/logger.go @@ -0,0 +1,13 @@ +package stomp + +type Logger interface { + Debugf(format string, value ...interface{}) + Infof(format string, value ...interface{}) + Warningf(format string, value ...interface{}) + Errorf(format string, value ...interface{}) + + Debug(message string) + Info(message string) + Warning(message string) + Error(message string) +} diff --git a/backend/services/stomp/message.go b/backend/services/stomp/message.go new file mode 100644 index 0000000..e53f042 --- /dev/null +++ b/backend/services/stomp/message.go @@ -0,0 +1,68 @@ +package stomp + +import ( + "io" + "github.com/go-stomp/stomp/v3/frame" +) + +// A Message represents a message received from the STOMP server. +// In most cases a message corresponds to a single STOMP MESSAGE frame +// received from the STOMP server. If, however, the Err field is non-nil, +// then the message corresponds to a STOMP ERROR frame, or a connection +// error between the client and the server. +type Message struct { + // Indicates whether an error was received on the subscription. + // The error will contain details of the error. If the server + // sent an ERROR frame, then the Body, ContentType and Header fields + // will be populated according to the contents of the ERROR frame. + Err error + + // Destination the message has been sent to. + Destination string + + // MIME content type. + ContentType string // MIME content + + // Connection that the message was received on. + Conn *Conn + + // Subscription associated with the message. + Subscription *Subscription + + // Optional header entries. When received from the server, + // these are the header entries received with the message. + Header *frame.Header + + // The message body, which is an arbitrary sequence of bytes. + // The ContentType indicates the format of this body. + Body []byte // Content of message +} + +// ShouldAck returns true if this message should be acknowledged to +// the STOMP server that sent it. +func (msg *Message) ShouldAck() bool { + if msg.Subscription == nil { + // not received from the server, so no acknowledgement required + return false + } + + return msg.Subscription.AckMode() != AckAuto +} + +func (msg *Message) Read(p []byte) (int, error) { + if len(msg.Body) == 0 { + return 0, io.EOF + } + n := copy(p, msg.Body) + msg.Body = msg.Body[n:] + return n, nil +} + +func (msg *Message) ReadByte() (byte, error) { + if len(msg.Body) == 0 { + return 0, io.EOF + } + n := msg.Body[0] + msg.Body = msg.Body[1:] + return n, nil +} diff --git a/backend/services/stomp/send_options.go b/backend/services/stomp/send_options.go new file mode 100644 index 0000000..bd81b00 --- /dev/null +++ b/backend/services/stomp/send_options.go @@ -0,0 +1,55 @@ +package stomp + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// SendOpt contains options for for the Conn.Send and Transaction.Send functions. +var SendOpt struct { + // Receipt specifies that the client should request acknowledgement + // from the server before the send operation successfully completes. + Receipt func(*frame.Frame) error + + // NoContentLength specifies that the SEND frame should not include + // a content-length header entry. By default the content-length header + // entry is always included, but some message brokers assign special + // meaning to STOMP frames that do not contain a content-length + // header entry. (In particular ActiveMQ interprets STOMP frames + // with no content-length as being a text message) + NoContentLength func(*frame.Frame) error + + // Header provides the opportunity to include custom header entries + // in the SEND frame that the client sends to the server. This option + // can be specified multiple times if multiple custom header entries + // are required. + Header func(key, value string) func(*frame.Frame) error +} + +func init() { + SendOpt.Receipt = func(f *frame.Frame) error { + if f.Command != frame.SEND { + return ErrInvalidCommand + } + id := allocateId() + f.Header.Set(frame.Receipt, id) + return nil + } + + SendOpt.NoContentLength = func(f *frame.Frame) error { + if f.Command != frame.SEND { + return ErrInvalidCommand + } + f.Header.Del(frame.ContentLength) + return nil + } + + SendOpt.Header = func(key, value string) func(*frame.Frame) error { + return func(f *frame.Frame) error { + if f.Command != frame.SEND { + return ErrInvalidCommand + } + f.Header.Add(key, value) + return nil + } + } +} diff --git a/backend/services/stomp/server/client/channel_test.go b/backend/services/stomp/server/client/channel_test.go new file mode 100644 index 0000000..083a600 --- /dev/null +++ b/backend/services/stomp/server/client/channel_test.go @@ -0,0 +1,88 @@ +package client + +import ( + . "gopkg.in/check.v1" +) + +// Test suite for testing that channels work the way I expect. +type ChannelSuite struct{} + +var _ = Suite(&ChannelSuite{}) + +func (s *ChannelSuite) TestChannelWhenClosed(c *C) { + + ch := make(chan int, 10) + + ch <- 1 + ch <- 2 + + select { + case i, ok := <-ch: + c.Assert(i, Equals, 1) + c.Assert(ok, Equals, true) + default: + c.Error("expected value on channel") + } + + select { + case i := <-ch: + c.Assert(i, Equals, 2) + default: + c.Error("expected value on channel") + } + + select { + case _ = <-ch: + c.Error("not expecting anything on the channel") + default: + } + + ch <- 3 + close(ch) + + select { + case i := <-ch: + c.Assert(i, Equals, 3) + default: + c.Error("expected value on channel") + } + + select { + case _, ok := <-ch: + c.Assert(ok, Equals, false) + default: + c.Error("expected value on channel") + } + + select { + case _, ok := <-ch: + c.Assert(ok, Equals, false) + default: + c.Error("expected value on channel") + } +} + +func (s *ChannelSuite) TestMultipleChannels(c *C) { + + ch1 := make(chan int, 10) + ch2 := make(chan string, 10) + + ch1 <- 1 + + select { + case i, ok := <-ch1: + c.Assert(i, Equals, 1) + c.Assert(ok, Equals, true) + case _ = <-ch2: + default: + c.Error("expected value on channel") + } + + select { + case _ = <-ch1: + c.Error("not expected") + case _ = <-ch2: + c.Error("not expected") + default: + } +} diff --git a/backend/services/stomp/server/client/client.go b/backend/services/stomp/server/client/client.go new file mode 100644 index 0000000..33cf590 --- /dev/null +++ b/backend/services/stomp/server/client/client.go @@ -0,0 +1,7 @@ +/* +Package client implements client connectivity in the STOMP server. + +The key abstractions include a connection, a subscription and +a client request. +*/ +package client diff --git a/backend/services/stomp/server/client/client_test.go b/backend/services/stomp/server/client/client_test.go new file mode 100644 index 0000000..034e86d --- /dev/null +++ b/backend/services/stomp/server/client/client_test.go @@ -0,0 +1,12 @@ +package client + +import ( + "gopkg.in/check.v1" + "testing" +) + +// Runs all gocheck tests in this package. +// See other *_test.go files for gocheck tests. +func TestClient(t *testing.T) { + check.TestingT(t) +} diff --git a/backend/services/stomp/server/client/config.go b/backend/services/stomp/server/client/config.go new file mode 100644 index 0000000..1050090 --- /dev/null +++ b/backend/services/stomp/server/client/config.go @@ -0,0 +1,25 @@ +package client + +import ( + "time" + + "github.com/go-stomp/stomp/v3" +) + +// Contains information the client package needs from the +// rest of the STOMP server code. +type Config interface { + // Method to authenticate a login and associated passcode. + // Returns true if login/passcode is valid, false otherwise. + Authenticate(login, passcode string) bool + + // Default duration for read/write heart-beat values. If this + // returns zero, no heart-beat will take place. If this value is + // larger than the maximu permitted value (which is more than + // 11 days, but less than 12 days), then it is truncated to the + // maximum permitted values. + HeartBeat() time.Duration + + // Logger provides the logger for a client + Logger() stomp.Logger +} diff --git a/backend/services/stomp/server/client/conn.go b/backend/services/stomp/server/client/conn.go new file mode 100644 index 0000000..9fb7e45 --- /dev/null +++ b/backend/services/stomp/server/client/conn.go @@ -0,0 +1,781 @@ +package client + +import ( + "fmt" + "io" + "net" + "strconv" + "time" + + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/frame" +) + +// Maximum number of pending frames allowed to a client. +// before a disconnect occurs. If the client cannot keep +// up with the server, we do not want the server to backlog +// pending frames indefinitely. +const maxPendingWrites = 16 + +// Maximum number of pending frames allowed before the read +// go routine starts blocking. +const maxPendingReads = 16 + +// Represents a connection with the STOMP client. +type Conn struct { + config Config + rw net.Conn // Network connection to client + writer *frame.Writer // Writes STOMP frames directly to the network connection + requestChannel chan Request // For sending requests to upper layer + subChannel chan *Subscription // Receives subscription messages for client + writeChannel chan *frame.Frame // Receives unacknowledged (topic) messages for client + readChannel chan *frame.Frame // Receives frames from the client + stateFunc func(c *Conn, f *frame.Frame) error // State processing function + writeTimeout time.Duration // Heart beat write timeout + version stomp.Version // Negotiated STOMP protocol version + closed bool // Is the connection closed + txStore *txStore // Stores transactions in progress + lastMsgId uint64 // last message-id value + subList *SubscriptionList // List of subscriptions requiring acknowledgement + subs map[string]*Subscription // All subscriptions, keyed by id + validator stomp.Validator // For validating STOMP frames + log stomp.Logger +} + +// Creates a new client connection. The config parameter contains +// process-wide configuration parameters relevant to a client connection. +// The rw parameter is a network connection object for communicating with +// the client. All client requests are sent via the ch channel to the +// upper layer. +func NewConn(config Config, rw net.Conn, ch chan Request) *Conn { + c := &Conn{ + config: config, + rw: rw, + requestChannel: ch, + subChannel: make(chan *Subscription, maxPendingWrites), + writeChannel: make(chan *frame.Frame, maxPendingWrites), + readChannel: make(chan *frame.Frame, maxPendingReads), + txStore: &txStore{}, + subList: NewSubscriptionList(), + subs: make(map[string]*Subscription), + log: config.Logger(), + } + go c.readLoop() + go c.processLoop() + return c +} + +// Write a frame to the connection without requiring +// any acknowledgement. +func (c *Conn) Send(f *frame.Frame) { + // Place the frame on the write channel. If the + // write channel is full, the caller will block. + c.writeChannel <- f +} + +// Send and ERROR message to the client. The client +// connection will disconnect as soon as the ERROR +// message has been transmitted. The message header +// will be based on the contents of the err parameter. +func (c *Conn) SendError(err error) { + f := frame.New(frame.ERROR, frame.Message, err.Error()) + c.Send(f) // will close after successful send +} + +// Send an ERROR frame to the client and immediately. The error +// message is derived from err. If f is non-nil, it is the frame +// whose contents have caused the error. Include the receipt-id +// header if the frame contains a receipt header. +func (c *Conn) sendErrorImmediately(err error, f *frame.Frame) { + errorFrame := frame.New(frame.ERROR, + frame.Message, err.Error()) + + // Include a receipt-id header if the frame that prompted the error had + // a receipt header (as suggested by the STOMP protocol spec). + if f != nil { + if receipt, ok := f.Header.Contains(frame.Receipt); ok { + errorFrame.Header.Add(frame.ReceiptId, receipt) + } + } + + // send the frame to the client, ignore any error condition + // because we are about to close the connection anyway + _ = c.sendImmediately(errorFrame) +} + +// Sends a STOMP frame to the client immediately, does not push onto the +// write channel to be processed in turn. +func (c *Conn) sendImmediately(f *frame.Frame) error { + return c.writer.Write(f) +} + +// Go routine for reading bytes from a client and assembling into +// STOMP frames. Also handles heart-beat read timeout. All read +// frames are pushed onto the read channel to be processed by the +// processLoop go-routine. This keeps all processing of frames for +// this connection on the one go-routine and avoids race conditions. +func (c *Conn) readLoop() { + reader := frame.NewReader(c.rw) + expectingConnect := true + readTimeout := time.Duration(0) + for { + if readTimeout == time.Duration(0) { + // infinite timeout + c.rw.SetReadDeadline(time.Time{}) + } else { + c.rw.SetReadDeadline(time.Now().Add(readTimeout * 2)) + } + f, err := reader.Read() + if err != nil { + if err == io.EOF { + c.log.Errorf("connection closed: %s", c.rw.RemoteAddr()) + } else { + c.log.Errorf("read failed: %v : %s", err, c.rw.RemoteAddr()) + } + + // Close the read channel so that the processing loop will + // know to terminate, if it has not already done so. This is + // the only channel that we close, because it is the only one + // we know who is writing to. + close(c.readChannel) + return + } + + if f == nil { + // if the frame is nil, then it is a heartbeat + continue + } + + // If we are expecting a CONNECT or STOMP command, extract + // the heart-beat header and work out the read timeout. + // Note that the processing loop will duplicate this to + // some extent, but letting this go-routine work out its own + // read timeout means no synchronization is necessary. + if expectingConnect { + // Expecting a CONNECT or STOMP command, get the heart-beat + cx, _, err := getHeartBeat(f) + + // Ignore the error condition and treat as no read timeout. + // The processing loop will handle the error again and + // process correctly. + if err == nil { + // Minimum value as per server config. If the client + // has requested shorter periods than this value, the + // server will insist on the longer time period. + min := asMilliseconds(c.config.HeartBeat(), maxHeartBeat) + + // apply a minimum heartbeat + if cx > 0 && cx < min { + cx = min + } + + readTimeout = time.Duration(cx) * time.Millisecond + + expectingConnect = false + } + } + + // Add the frame to the read channel. Note that this will block + // if we are reading from the client quicker than the server + // can process frames. + c.readChannel <- f + } +} + +// Go routine that processes all read frames and all write frames. +// Having all processing in one go routine helps eliminate any race conditions. +func (c *Conn) processLoop() { + defer c.cleanupConn() + + c.writer = frame.NewWriter(c.rw) + c.stateFunc = connecting + + var timerChannel <-chan time.Time + var timer *time.Timer + for { + if c.writeTimeout > 0 && timer == nil { + timer = time.NewTimer(c.writeTimeout) + timerChannel = timer.C + } + + select { + case f, ok := <-c.writeChannel: + if !ok { + // write channel has been closed, so + // exit go-routine (after cleaning up) + return + } + + // have a frame to the client with + // no acknowledgement required (topic) + + // stop the heart-beat timer + if timer != nil { + timer.Stop() + timer = nil + } + + c.allocateMessageId(f, nil) + + // write the frame to the client + err := c.writer.Write(f) + if err != nil { + // if there is an error writing to + // the client, there is not much + // point trying to send an ERROR frame, + // so just exit go-routine (after cleaning up) + return + } + + // if the frame just sent to the client is an error + // frame, we disconnect + if f.Command == frame.ERROR { + // sent an ERROR frame, so disconnect + return + } + + case f, ok := <-c.readChannel: + if !ok { + // read channel has been closed, so + // exit go-routine (after cleaning up) + return + } + + // Just received a frame from the client. + // Validate the frame, checking for mandatory + // headers and prohibited headers. + if c.validator != nil { + err := c.validator.Validate(f) + if err != nil { + c.log.Warningf("validation failed for %s frame: %v", f.Command, err) + c.sendErrorImmediately(err, f) + return + } + } + + // Pass to the appropriate function for handling + // according to the current state of the connection. + err := c.stateFunc(c, f) + if err != nil { + c.sendErrorImmediately(err, f) + return + } + + case sub, ok := <-c.subChannel: + if !ok { + // subscription channel has been closed, + // so exit go-routine (after cleaning up) + return + } + + // have a frame to the client which requires + // acknowledgement to the upper layer + + // stop the heart-beat timer + if timer != nil { + timer.Stop() + timer = nil + } + + // there is the possibility that the subscription + // has been unsubscribed just prior to receiving + // this, so we check + if _, ok = c.subs[sub.id]; ok { + // allocate a message-id, note that the + // subscription id has already been set + c.allocateMessageId(sub.frame, sub) + + // write the frame to the client + err := c.writer.Write(sub.frame) + if err != nil { + // if there is an error writing to + // the client, there is not much + // point trying to send an ERROR frame, + // so just exit go-routine (after cleaning up) + return + } + + if sub.ack == frame.AckAuto { + // subscription does not require acknowledgement, + // so send the subscription back the upper layer + // straight away + sub.frame = nil + c.requestChannel <- Request{Op: SubscribeOp, Sub: sub} + } else { + // subscription requires acknowledgement + c.subList.Add(sub) + } + } else { + // Subscription no longer exists, requeue + c.requestChannel <- Request{Op: RequeueOp, Frame: sub.frame} + } + + case _ = <-timerChannel: + // stop the heart-beat timer + if timer != nil { + timer.Stop() + timer = nil + } + // write a heart-beat + err := c.writer.Write(nil) + if err != nil { + return + } + } + } +} + +// Called when the connection is closing, and takes care of +// unsubscribing all subscriptions with the upper layer, and +// re-queueing all unacknowledged messages to the upper layer. +func (c *Conn) cleanupConn() { + // clean up any pending transactions + c.txStore.Init() + + c.discardWriteChannelFrames() + + // Unsubscribe every subscription known to the upper layer. + // This should be done before cleaning up the subscription + // channel. If we requeued messages before doing this, + // we might end up getting them back again. + for _, sub := range c.subs { + // Note that we only really need to send a request if the + // subscription does not have a frame, but for simplicity + // all subscriptions are unsubscribed from the upper layer. + c.requestChannel <- Request{Op: UnsubscribeOp, Sub: sub} + } + + // Clear out the map of subscriptions + c.subs = nil + + // Every subscription requiring acknowledgement has a frame + // that needs to be requeued in the upper layer + for sub := c.subList.Get(); sub != nil; sub = c.subList.Get() { + c.requestChannel <- Request{Op: RequeueOp, Frame: sub.frame} + } + + // empty the subscription and write queue + c.discardWriteChannelFrames() + c.cleanupSubChannel() + + // Tell the upper layer we are now disconnected + c.requestChannel <- Request{Op: DisconnectedOp, Conn: c} + + // empty the subscription and write queue one more time + c.discardWriteChannelFrames() + c.cleanupSubChannel() + + // Should not hurt to call this if it is already closed? + c.rw.Close() +} + +// Discard anything on the write channel. These frames +// do not get acknowledged, and are either topic MESSAGE +// frames or ERROR frames. +func (c *Conn) discardWriteChannelFrames() { + for finished := false; !finished; { + select { + case _, ok := <-c.writeChannel: + if !ok { + finished = true + } + + default: + finished = true + } + } +} + +func (c *Conn) cleanupSubChannel() { + // Read the subscription channel until it is empty. + // Each frame should be requeued to the upper layer. + for finished := false; !finished; { + select { + case sub, ok := <-c.subChannel: + if !ok { + finished = true + } else { + c.requestChannel <- Request{Op: RequeueOp, Frame: sub.frame} + } + + default: + finished = true + } + } +} + +// Send a frame to the client, allocating necessary headers prior. +func (c *Conn) allocateMessageId(f *frame.Frame, sub *Subscription) { + if f.Command == frame.MESSAGE || f.Command == frame.ACK { + // allocate the value of message-id for this frame + c.lastMsgId++ + messageId := strconv.FormatUint(c.lastMsgId, 10) + f.Header.Set(frame.MessageId, messageId) + f.Header.Set(frame.Id, messageId) + + // if there is any requirement by the client to acknowledge, set + // the ack header as per STOMP 1.2 + if sub == nil || sub.ack == frame.AckAuto { + f.Header.Del(frame.Ack) + } else { + f.Header.Set(frame.Ack, messageId) + } + } +} + +// State function for expecting connect frame. +func connecting(c *Conn, f *frame.Frame) error { + switch f.Command { + case frame.CONNECT, frame.STOMP: + return c.handleConnect(f) + } + return notConnected +} + +// State function for after connect frame received. +func connected(c *Conn, f *frame.Frame) error { + switch f.Command { + case frame.CONNECT, frame.STOMP: + return unexpectedCommand + case frame.DISCONNECT: + return c.handleDisconnect(f) + case frame.BEGIN: + return c.handleBegin(f) + case frame.ABORT: + return c.handleAbort(f) + case frame.COMMIT: + return c.handleCommit(f) + case frame.SEND: + return c.handleSend(f) + case frame.SUBSCRIBE: + return c.handleSubscribe(f) + case frame.UNSUBSCRIBE: + return c.handleUnsubscribe(f) + case frame.ACK: + return c.handleAck(f) + case frame.NACK: + return c.handleNack(f) + case frame.MESSAGE, frame.RECEIPT, frame.ERROR: + // should only be sent by the server, should not come from the client + return unexpectedCommand + } + return unknownCommand +} + +func (c *Conn) handleConnect(f *frame.Frame) error { + var err error + + if _, ok := f.Header.Contains(frame.Receipt); ok { + // CONNNECT and STOMP frames are not allowed to have + // a receipt header. + return receiptInConnect + } + + // if either of these fields are absent, pass nil to the + // authenticator function. + login, _ := f.Header.Contains(frame.Login) + passcode, _ := f.Header.Contains(frame.Passcode) + if !c.config.Authenticate(login, passcode) { + // sleep to slow down a rogue client a little bit + c.log.Error("authentication failed") + time.Sleep(time.Second) + return authenticationFailed + } + + c.version, err = determineVersion(f) + if err != nil { + c.log.Error("protocol version negotiation failed") + return err + } + c.validator = stomp.NewValidator(c.version) + + if c.version == stomp.V10 { + // don't want to handle V1.0 at the moment + // TODO: get working for V1.0 + c.log.Errorf("unsupported version %s", c.version) + return unsupportedVersion + } + + cx, cy, err := getHeartBeat(f) + if err != nil { + c.log.Error("invalid heart-beat") + return err + } + + // Minimum value as per server config. If the client + // has requested shorter periods than this value, the + // server will insist on the longer time period. + min := asMilliseconds(c.config.HeartBeat(), maxHeartBeat) + + // apply a minimum heartbeat + if cx > 0 && cx < min { + cx = min + } + if cy > 0 && cy < min { + cy = min + } + + // the read timeout has already been processed in the readLoop + // go-routine + c.writeTimeout = time.Duration(cy) * time.Millisecond + + /* TR-369 section 4.4.1.1 [Connecting a USP Endpoint to the STOMP Server] */ + /* + R-STOMP.4: USP Endpoints sending a STOMP frame MUST include (in addition to other + mandatory STOMP headers) an endpoint-id STOMP header containing the + Endpoint ID of the USP Endpoint sending the frame. + */ + endpointId := f.Header.Get("endpoint-id") + + response := frame.New(frame.CONNECTED, + frame.Version, string(c.version), + frame.Server, "stompd/x.y.z", // TODO: get version + frame.HeartBeat, fmt.Sprintf("%d,%d", cy, cx), + frame.SubscribeDest, "oktopus/v1/agent/"+endpointId, + ) + + c.sendImmediately(response) + c.stateFunc = connected + + // tell the upper layer we are connected + c.requestChannel <- Request{Op: ConnectedOp, Conn: c} + + return nil +} + +// Sends a RECEIPT frame to the client if the frame f contains +// a receipt header. If the frame does contain a receipt header, +// it will be removed from the frame. +func (c *Conn) sendReceiptImmediately(f *frame.Frame) error { + if receipt, ok := f.Header.Contains(frame.Receipt); ok { + // Remove the receipt header from the frame. This is handy + // for transactions, because the frame has its receipt + // header removed prior to entering the transaction store. + // When the frame is processed upon transaction commit, it + // will not have a receipt header anymore. + f.Header.Del(frame.Receipt) + return c.sendImmediately(frame.New(frame.RECEIPT, + frame.ReceiptId, receipt)) + } + return nil +} + +func (c *Conn) handleDisconnect(f *frame.Frame) error { + // As soon as we receive a DISCONNECT frame from a client, we do + // not want to send any more frames to that client, with the exception + // of a RECEIPT frame if the client has requested one. + // Ignore the error condition if we cannot send a RECEIPT frame, + // as the connection is about to close anyway. + _ = c.sendReceiptImmediately(f) + return nil +} + +func (c *Conn) handleBegin(f *frame.Frame) error { + // the frame should already have been validated for the + // transaction header, but we check again here. + if transaction, ok := f.Header.Contains(frame.Transaction); ok { + // Send a receipt and remove the header + err := c.sendReceiptImmediately(f) + if err != nil { + return err + } + + return c.txStore.Begin(transaction) + } + return missingHeader(frame.Transaction) +} + +func (c *Conn) handleCommit(f *frame.Frame) error { + // the frame should already have been validated for the + // transaction header, but we check again here. + if transaction, ok := f.Header.Contains(frame.Transaction); ok { + // Send a receipt and remove the header + err := c.sendReceiptImmediately(f) + if err != nil { + return err + } + return c.txStore.Commit(transaction, func(f *frame.Frame) error { + // Call the state function (again) for each frame in the + // transaction. This time each frame is stripped of its transaction + // header (and its receipt header as well, if it had one). + return c.stateFunc(c, f) + }) + } + return missingHeader(frame.Transaction) +} + +func (c *Conn) handleAbort(f *frame.Frame) error { + // the frame should already have been validated for the + // transaction header, but we check again here. + if transaction, ok := f.Header.Contains(frame.Transaction); ok { + // Send a receipt and remove the header + err := c.sendReceiptImmediately(f) + if err != nil { + return err + } + return c.txStore.Abort(transaction) + } + return missingHeader(frame.Transaction) +} + +func (c *Conn) handleSubscribe(f *frame.Frame) error { + id, ok := f.Header.Contains(frame.Id) + if !ok { + return missingHeader(frame.Id) + } + + dest, ok := f.Header.Contains(frame.Destination) + if !ok { + return missingHeader(frame.Destination) + } + + ack, ok := f.Header.Contains(frame.Ack) + if !ok { + ack = frame.AckAuto + } + + sub, ok := c.subs[id] + if ok { + return subscriptionExists + } + + sub = newSubscription(c, dest, id, ack) + c.subs[id] = sub + + // send information about new subscription to upper layer + c.requestChannel <- Request{Op: SubscribeOp, Sub: sub} + return nil +} + +func (c *Conn) handleUnsubscribe(f *frame.Frame) error { + id, ok := f.Header.Contains(frame.Id) + if !ok { + return missingHeader(frame.Id) + } + + sub, ok := c.subs[id] + if !ok { + return subscriptionNotFound + } + + // remove the subscription + delete(c.subs, id) + + // tell the upper layer of the unsubscribe + c.requestChannel <- Request{Op: UnsubscribeOp, Sub: sub} + return nil +} + +func (c *Conn) handleAck(f *frame.Frame) error { + var err error + var msgId string + + if ack, ok := f.Header.Contains(frame.Ack); ok { + msgId = ack + } else if msgId, ok = f.Header.Contains(frame.MessageId); !ok { + return missingHeader(frame.MessageId) + } + + // expecting message id to be a uint64 + msgId64, err := strconv.ParseUint(msgId, 10, 64) + if err != nil { + return err + } + + // Send a receipt and remove the header + err = c.sendReceiptImmediately(f) + if err != nil { + return err + } + + if tx, ok := f.Header.Contains(frame.Transaction); ok { + // the transaction header is removed from the frame + err = c.txStore.Add(tx, f) + if err != nil { + return err + } + } else { + // handle any subscriptions that are acknowledged by this msg + c.subList.Ack(msgId64, func(s *Subscription) { + // remove frame from the subscription, it has been delivered + s.frame = nil + + // let the upper layer know that this subscription + // is ready for another frame + c.requestChannel <- Request{Op: SubscribeOp, Sub: s} + }) + } + + return nil +} + +func (c *Conn) handleNack(f *frame.Frame) error { + var err error + var msgId string + + if ack, ok := f.Header.Contains(frame.Ack); ok { + msgId = ack + } else if msgId, ok = f.Header.Contains(frame.MessageId); !ok { + return missingHeader(frame.MessageId) + } + + // expecting message id to be a uint64 + msgId64, err := strconv.ParseUint(msgId, 10, 64) + if err != nil { + return err + } + + // Send a receipt and remove the header + err = c.sendReceiptImmediately(f) + if err != nil { + return err + } + + if tx, ok := f.Header.Contains(frame.Transaction); ok { + // the transaction header is removed from the frame + err = c.txStore.Add(tx, f) + if err != nil { + return err + } + } else { + // handle any subscriptions that are acknowledged by this msg + c.subList.Nack(msgId64, func(s *Subscription) { + // send frame back to upper layer for requeue + c.requestChannel <- Request{Op: RequeueOp, Frame: s.frame} + + // remove frame from the subscription, it has been requeued + s.frame = nil + + // let the upper layer know that this subscription + // is ready for another frame + c.requestChannel <- Request{Op: SubscribeOp, Sub: s} + }) + } + return nil +} + +// Handle a SEND frame received from the client. Note that +// this method is called after a SEND message is received, +// but also after a transaction commit. +func (c *Conn) handleSend(f *frame.Frame) error { + // Send a receipt and remove the header + err := c.sendReceiptImmediately(f) + if err != nil { + return err + } + + if tx, ok := f.Header.Contains(frame.Transaction); ok { + // the transaction header is removed from the frame + err = c.txStore.Add(tx, f) + if err != nil { + return err + } + } else { + // not in a transaction + // change from SEND to MESSAGE + f.Command = frame.MESSAGE + c.requestChannel <- Request{Op: EnqueueOp, Frame: f} + } + + return nil +} diff --git a/backend/services/stomp/server/client/errors.go b/backend/services/stomp/server/client/errors.go new file mode 100644 index 0000000..0488c37 --- /dev/null +++ b/backend/services/stomp/server/client/errors.go @@ -0,0 +1,36 @@ +package client + +const ( + notConnected = errorMessage("expected CONNECT or STOMP frame") + unexpectedCommand = errorMessage("unexpected frame command") + unknownCommand = errorMessage("unknown command") + receiptInConnect = errorMessage("receipt header prohibited in CONNECT or STOMP frame") + authenticationFailed = errorMessage("authentication failed") + txAlreadyInProgress = errorMessage("transaction already in progress") + txUnknown = errorMessage("unknown transaction") + unsupportedVersion = errorMessage("unsupported version") + subscriptionExists = errorMessage("subscription already exists") + subscriptionNotFound = errorMessage("subscription not found") + invalidFrameFormat = errorMessage("invalid frame format") + invalidCommand = errorMessage("invalid command") + unknownVersion = errorMessage("incompatible version") + notConnectFrame = errorMessage("operation valid for STOMP and CONNECT frames only") + invalidHeartBeat = errorMessage("invalid format for heart-beat") + invalidOperationForFrame = errorMessage("invalid operation for frame") + exceededMaxFrameSize = errorMessage("exceeded max frame size") + invalidHeaderValue = errorMessage("invalid header value") +) + +type errorMessage string + +func (e errorMessage) Error() string { + return string(e) +} + +func missingHeader(name string) errorMessage { + return errorMessage("missing header: " + name) +} + +func prohibitedHeader(name string) errorMessage { + return errorMessage("prohibited header: " + name) +} diff --git a/backend/services/stomp/server/client/frame.go b/backend/services/stomp/server/client/frame.go new file mode 100644 index 0000000..ce967b6 --- /dev/null +++ b/backend/services/stomp/server/client/frame.go @@ -0,0 +1,119 @@ +package client + +import ( + "regexp" + "sort" + "strconv" + "strings" + + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/frame" +) + +const ( + // Maximum permitted heart-beat timeout, about 11.5 days. + // Any client CONNECT frame with a larger value than this + // will be rejected. + maxHeartBeat = 999999999 +) + +var ( + // Regexp for heart-beat header value + heartBeatRegexp = regexp.MustCompile("^[0-9]{1,9},[0-9]{1,9}$") +) + +// Determine the most acceptable version based on the accept-version +// header of a CONNECT or STOMP frame. +// +// Returns stomp.V10 if a CONNECT frame and the accept-version header is missing. +// +// Returns an error if the frame is not a CONNECT or STOMP frame, or +// if the accept-header is malformed or does not contain any compatible +// version numbers. Also returns an error if the accept-header is missing +// for a STOMP frame. +// +// Otherwise, returns the highest compatible version specified in the +// accept-version header. Compatible versions are V1_0, V1_1 and V1_2. +func determineVersion(f *frame.Frame) (version stomp.Version, err error) { + // frame can be CONNECT or STOMP with slightly different + // handling of accept-verion for each + isConnect := f.Command == frame.CONNECT + + if !isConnect && f.Command != frame.STOMP { + err = notConnectFrame + return + } + + // start with an error, and remove if successful + err = unknownVersion + + if acceptVersion, ok := f.Header.Contains(frame.AcceptVersion); ok { + // sort the versions so that the latest version comes last + versions := strings.Split(acceptVersion, ",") + sort.Strings(versions) + for _, v := range versions { + switch stomp.Version(v) { + case stomp.V10: + version = stomp.V10 + err = nil + case stomp.V11: + version = stomp.V11 + err = nil + case stomp.V12: + version = stomp.V12 + err = nil + } + } + } else { + // CONNECT frames can be missing the accept-version header, + // we assume V1.0 in this case. STOMP frames were introduced + // in V1.1, so they must have an accept-version header. + if isConnect { + // no "accept-version" header, so we assume 1.0 + version = stomp.V10 + err = nil + } else { + err = missingHeader(frame.AcceptVersion) + } + } + return +} + +// Determine the heart-beat values in a CONNECT or STOMP frame. +// +// Returns 0,0 if the heart-beat header is missing. Otherwise +// returns the cx and cy values in the frame. +// +// Returns an error if the heart-beat header is malformed, or if +// the frame is not a CONNECT or STOMP frame. In this implementation, +// a heart-beat header is considered malformed if either cx or cy +// is greater than MaxHeartBeat. +func getHeartBeat(f *frame.Frame) (cx, cy int, err error) { + if f.Command != frame.CONNECT && + f.Command != frame.STOMP && + f.Command != frame.CONNECTED { + err = invalidOperationForFrame + return + } + if heartBeat, ok := f.Header.Contains(frame.HeartBeat); ok { + if !heartBeatRegexp.MatchString(heartBeat) { + err = invalidHeartBeat + return + } + + // no error checking here because we are confident + // that everything will work because the regexp matches. + slice := strings.Split(heartBeat, ",") + value1, _ := strconv.ParseUint(slice[0], 10, 32) + value2, _ := strconv.ParseUint(slice[1], 10, 32) + cx = int(value1) + cy = int(value2) + } else { + // heart-beat header not present + // this else clause is not necessary, but + // included for clarity. + cx = 0 + cy = 0 + } + return +} diff --git a/backend/services/stomp/server/client/frame_test.go b/backend/services/stomp/server/client/frame_test.go new file mode 100644 index 0000000..508e6cb --- /dev/null +++ b/backend/services/stomp/server/client/frame_test.go @@ -0,0 +1,82 @@ +package client + +import ( + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/frame" + . "gopkg.in/check.v1" +) + +type FrameSuite struct{} + +var _ = Suite(&FrameSuite{}) + +func (s *FrameSuite) TestDetermineVersion_V10_Connect(c *C) { + f := frame.New(frame.CONNECT) + version, err := determineVersion(f) + c.Check(err, IsNil) + c.Check(version, Equals, stomp.V10) +} + +func (s *FrameSuite) TestDetermineVersion_V10_Stomp(c *C) { + // the "STOMP" command was introduced in V1.1, so it must + // have an accept-version header + f := frame.New(frame.STOMP) + _, err := determineVersion(f) + c.Check(err, Equals, missingHeader(frame.AcceptVersion)) +} + +func (s *FrameSuite) TestDetermineVersion_V11_Connect(c *C) { + f := frame.New(frame.CONNECT) + f.Header.Add(frame.AcceptVersion, "1.1") + version, err := determineVersion(f) + c.Check(version, Equals, stomp.V11) + c.Check(err, IsNil) +} + +func (s *FrameSuite) TestDetermineVersion_MultipleVersions(c *C) { + f := frame.New(frame.CONNECT) + f.Header.Add(frame.AcceptVersion, "1.2,1.1,1.0,2.0") + version, err := determineVersion(f) + c.Check(version, Equals, stomp.V12) + c.Check(err, IsNil) +} + +func (s *FrameSuite) TestDetermineVersion_IncompatibleVersions(c *C) { + f := frame.New(frame.CONNECT) + f.Header.Add(frame.AcceptVersion, "0.2,0.1,1.3,2.0") + version, err := determineVersion(f) + c.Check(version, Equals, stomp.Version("")) + c.Check(err, Equals, unknownVersion) +} + +func (s *FrameSuite) TestHeartBeat(c *C) { + f := frame.New(frame.CONNECT, + frame.AcceptVersion, "1.2", + frame.Host, "XX") + + // no heart-beat header means zero values + x, y, err := getHeartBeat(f) + c.Check(x, Equals, 0) + c.Check(y, Equals, 0) + c.Check(err, IsNil) + + f.Header.Add("heart-beat", "123,456") + x, y, err = getHeartBeat(f) + c.Check(x, Equals, 123) + c.Check(y, Equals, 456) + c.Check(err, IsNil) + + f.Header.Set(frame.HeartBeat, "invalid") + x, y, err = getHeartBeat(f) + c.Check(x, Equals, 0) + c.Check(y, Equals, 0) + c.Check(err, Equals, invalidHeartBeat) + + f.Header.Del(frame.HeartBeat) + _, _, err = getHeartBeat(f) + c.Check(err, IsNil) + + f.Command = frame.SEND + _, _, err = getHeartBeat(f) + c.Check(err, Equals, invalidOperationForFrame) +} diff --git a/backend/services/stomp/server/client/request.go b/backend/services/stomp/server/client/request.go new file mode 100644 index 0000000..3b6aa40 --- /dev/null +++ b/backend/services/stomp/server/client/request.go @@ -0,0 +1,32 @@ +package client + +import ( + "strconv" + + "github.com/go-stomp/stomp/v3/frame" +) + +// Opcode used in client requests. +type RequestOp int + +func (r RequestOp) String() string { + return strconv.Itoa(int(r)) +} + +// Valid value for client request opcodes. +const ( + SubscribeOp RequestOp = iota // subscription ready + UnsubscribeOp // subscription not ready + EnqueueOp // send a message to a queue + RequeueOp // re-queue a message, not successfully sent + ConnectedOp // connection established + DisconnectedOp // connection disconnected +) + +// Client requests received to be processed by main processing loop +type Request struct { + Op RequestOp // opcode for request + Sub *Subscription // SubscribeOp, UnsubscribeOp + Frame *frame.Frame // EnqueueOp, RequeueOp + Conn *Conn // ConnectedOp, DisconnectedOp +} diff --git a/backend/services/stomp/server/client/subscription.go b/backend/services/stomp/server/client/subscription.go new file mode 100644 index 0000000..a9f1757 --- /dev/null +++ b/backend/services/stomp/server/client/subscription.go @@ -0,0 +1,84 @@ +package client + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +type Subscription struct { + conn *Conn + dest string + id string // client's subscription id + ack string // auto, client, client-individual + msgId uint64 // message-id (or ack) for acknowledgement + subList *SubscriptionList // am I in a list + frame *frame.Frame // message allocated to subscription +} + +func newSubscription(c *Conn, dest string, id string, ack string) *Subscription { + return &Subscription{ + conn: c, + dest: dest, + id: id, + ack: ack, + } +} + +func (s *Subscription) Destination() string { + return s.dest +} + +func (s *Subscription) Ack() string { + return s.ack +} + +func (s *Subscription) Id() string { + return s.id +} + +func (s *Subscription) IsAckedBy(msgId uint64) bool { + switch s.ack { + case frame.AckAuto: + return true + case frame.AckClient: + // any later message acknowledges an earlier message + return msgId >= s.msgId + case frame.AckClientIndividual: + return msgId == s.msgId + } + + // should not get here + panic("invalid value for subscript.ack") +} + +func (s *Subscription) IsNackedBy(msgId uint64) bool { + // TODO: not sure about this, interpreting NACK + // to apply to an individual message + return msgId == s.msgId +} + +func (s *Subscription) SendQueueFrame(f *frame.Frame) { + s.setSubscriptionHeader(f) + s.frame = f + + // let the connection deal with the subscription + // acknowledgement + s.conn.subChannel <- s +} + +// Send a message frame to the client, as part of this +// subscription. Called within the queue when a message +// frame is available. +func (s *Subscription) SendTopicFrame(f *frame.Frame) { + s.setSubscriptionHeader(f) + + // topics are handled differently, they just go + // straight to the client without acknowledgement + s.conn.writeChannel <- f +} + +func (s *Subscription) setSubscriptionHeader(f *frame.Frame) { + if s.frame != nil { + panic("subscription already has a frame pending") + } + f.Header.Set(frame.Subscription, s.id) +} diff --git a/backend/services/stomp/server/client/subscription_list.go b/backend/services/stomp/server/client/subscription_list.go new file mode 100644 index 0000000..5630517 --- /dev/null +++ b/backend/services/stomp/server/client/subscription_list.go @@ -0,0 +1,107 @@ +package client + +import ( + "container/list" +) + +// Maintains a list of subscriptions. Not thread-safe. +type SubscriptionList struct { + // TODO: implement linked list locally, adding next and prev + // pointers to the Subscription struct itself. + subs *list.List +} + +func NewSubscriptionList() *SubscriptionList { + return &SubscriptionList{list.New()} +} + +// Add a subscription to the back of the list. Will panic if +// the subscription destination does not match the subscription +// list destination. Will also panic if the subscription has already +// been added to a subscription list. +func (sl *SubscriptionList) Add(sub *Subscription) { + if sub.subList != nil { + panic("subscription is already in a subscription list") + } + sl.subs.PushBack(sub) + sub.subList = sl +} + +// Gets the first subscription in the list, or nil if there +// are no subscriptions available. The subscription is removed +// from the list. +func (sl *SubscriptionList) Get() *Subscription { + if sl.subs.Len() == 0 { + return nil + } + front := sl.subs.Front() + sub := front.Value.(*Subscription) + sl.subs.Remove(front) + sub.subList = nil + return sub +} + +// Removes the subscription from the list. +func (sl *SubscriptionList) Remove(s *Subscription) { + for e := sl.subs.Front(); e != nil; e = e.Next() { + if e.Value.(*Subscription) == s { + sl.subs.Remove(e) + s.subList = nil + return + } + } +} + +// Search for a subscription with the specified id and remove it. +// Returns a pointer to the subscription if found, nil otherwise. +func (sl *SubscriptionList) FindByIdAndRemove(id string) *Subscription { + for e := sl.subs.Front(); e != nil; e = e.Next() { + sub := e.Value.(*Subscription) + if sub.id == id { + sl.subs.Remove(e) + return sub + } + } + return nil +} + +// Finds all subscriptions in the subscription list that are acked by the +// specified message-id (or ack) header. The subscription is removed from +// the list and the callback function called for that subscription. +func (sl *SubscriptionList) Ack(msgId uint64, callback func(s *Subscription)) { + for e := sl.subs.Front(); e != nil; { + next := e.Next() + sub := e.Value.(*Subscription) + if sub.IsAckedBy(msgId) { + sl.subs.Remove(e) + callback(sub) + } + e = next + } +} + +// Finds all subscriptions in the subscription list that are *nacked* by the +// specified message-id (or ack) header. The subscription is removed from +// the list and the callback function called for that subscription. Current +// understanding that all NACKs are individual, but not sure +func (sl *SubscriptionList) Nack(msgId uint64, callback func(s *Subscription)) { + for e := sl.subs.Front(); e != nil; { + next := e.Next() + sub := e.Value.(*Subscription) + if sub.IsNackedBy(msgId) { + sl.subs.Remove(e) + callback(sub) + } + e = next + } +} + +// Invoke a callback function for every subscription in the list. +func (sl *SubscriptionList) ForEach(callback func(s *Subscription, isLast bool)) { + for e := sl.subs.Front(); e != nil; { + next := e.Next() + sub := e.Value.(*Subscription) + callback(sub, next == nil) + e = next + } +} diff --git a/backend/services/stomp/server/client/subscription_list_test.go b/backend/services/stomp/server/client/subscription_list_test.go new file mode 100644 index 0000000..8d471e6 --- /dev/null +++ b/backend/services/stomp/server/client/subscription_list_test.go @@ -0,0 +1,113 @@ +package client + +import ( + . "gopkg.in/check.v1" +) + +type SubscriptionListSuite struct{} + +var _ = Suite(&SubscriptionListSuite{}) + +func (s *SubscriptionListSuite) TestAddAndGet(c *C) { + sub1 := newSubscription(nil, "/dest", "1", "client") + sub2 := newSubscription(nil, "/dest", "2", "client") + sub3 := newSubscription(nil, "/dest", "3", "client") + + sl := NewSubscriptionList() + sl.Add(sub1) + sl.Add(sub2) + sl.Add(sub3) + + c.Check(sl.Get(), Equals, sub1) + + // add the subscription again, should go to the back + sl.Add(sub1) + + c.Check(sl.Get(), Equals, sub2) + c.Check(sl.Get(), Equals, sub3) + c.Check(sl.Get(), Equals, sub1) + + c.Check(sl.Get(), IsNil) +} + +func (s *SubscriptionListSuite) TestAddAndRemove(c *C) { + sub1 := newSubscription(nil, "/dest", "1", "client") + sub2 := newSubscription(nil, "/dest", "2", "client") + sub3 := newSubscription(nil, "/dest", "3", "client") + + sl := NewSubscriptionList() + sl.Add(sub1) + sl.Add(sub2) + sl.Add(sub3) + + c.Check(sl.subs.Len(), Equals, 3) + + // now remove the second subscription + sl.Remove(sub2) + + c.Check(sl.Get(), Equals, sub1) + c.Check(sl.Get(), Equals, sub3) + c.Check(sl.Get(), IsNil) +} + +func (s *SubscriptionListSuite) TestAck(c *C) { + sub1 := &Subscription{dest: "/dest1", id: "1", ack: "client", msgId: 101} + sub2 := &Subscription{dest: "/dest3", id: "2", ack: "client-individual", msgId: 102} + sub3 := &Subscription{dest: "/dest4", id: "3", ack: "client", msgId: 103} + sub4 := &Subscription{dest: "/dest4", id: "4", ack: "client", msgId: 104} + + sl := NewSubscriptionList() + sl.Add(sub1) + sl.Add(sub2) + sl.Add(sub3) + sl.Add(sub4) + + c.Check(sl.subs.Len(), Equals, 4) + + var subs []*Subscription + callback := func(s *Subscription) { + subs = append(subs, s) + } + + // now remove the second subscription + sl.Ack(103, callback) + + c.Assert(len(subs), Equals, 2) + c.Assert(subs[0], Equals, sub1) + c.Assert(subs[1], Equals, sub3) + + c.Assert(sl.Get(), Equals, sub2) + c.Assert(sl.Get(), Equals, sub4) + c.Assert(sl.Get(), IsNil) +} + +func (s *SubscriptionListSuite) TestNack(c *C) { + sub1 := &Subscription{dest: "/dest1", id: "1", ack: "client", msgId: 101} + sub2 := &Subscription{dest: "/dest3", id: "2", ack: "client-individual", msgId: 102} + sub3 := &Subscription{dest: "/dest4", id: "3", ack: "client", msgId: 103} + sub4 := &Subscription{dest: "/dest4", id: "4", ack: "client", msgId: 104} + + sl := NewSubscriptionList() + sl.Add(sub1) + sl.Add(sub2) + sl.Add(sub3) + sl.Add(sub4) + + c.Check(sl.subs.Len(), Equals, 4) + + var subs []*Subscription + callback := func(s *Subscription) { + subs = append(subs, s) + } + + // now remove the second subscription + sl.Nack(103, callback) + + c.Assert(len(subs), Equals, 1) + c.Assert(subs[0], Equals, sub3) + + c.Assert(sl.Get(), Equals, sub1) + c.Assert(sl.Get(), Equals, sub2) + c.Assert(sl.Get(), Equals, sub4) + c.Assert(sl.Get(), IsNil) +} diff --git a/backend/services/stomp/server/client/tx_store.go b/backend/services/stomp/server/client/tx_store.go new file mode 100644 index 0000000..d2c4354 --- /dev/null +++ b/backend/services/stomp/server/client/tx_store.go @@ -0,0 +1,65 @@ +package client + +import ( + "container/list" + + "github.com/go-stomp/stomp/v3/frame" +) + +type txStore struct { + transactions map[string]*list.List +} + +// Initializes a new store or clears out an existing store +func (txs *txStore) Init() { + txs.transactions = nil +} + +func (txs *txStore) Begin(tx string) error { + if txs.transactions == nil { + txs.transactions = make(map[string]*list.List) + } + + if _, ok := txs.transactions[tx]; ok { + return txAlreadyInProgress + } + + txs.transactions[tx] = list.New() + return nil +} + +func (txs *txStore) Abort(tx string) error { + if list, ok := txs.transactions[tx]; ok { + list.Init() + delete(txs.transactions, tx) + return nil + } + return txUnknown +} + +// Commit causes all requests that have been queued for the transaction +// to be sent to the request channel for processing. Calls the commit +// function (commitFunc) in order for each request that is part of the +// transaction. +func (txs *txStore) Commit(tx string, commitFunc func(f *frame.Frame) error) error { + if list, ok := txs.transactions[tx]; ok { + for element := list.Front(); element != nil; element = list.Front() { + err := commitFunc(list.Remove(element).(*frame.Frame)) + if err != nil { + return err + } + } + delete(txs.transactions, tx) + return nil + } + return txUnknown +} + +func (txs *txStore) Add(tx string, f *frame.Frame) error { + if list, ok := txs.transactions[tx]; ok { + f.Header.Del(frame.Transaction) + list.PushBack(f) + return nil + } + return txUnknown +} diff --git a/backend/services/stomp/server/client/tx_store_test.go b/backend/services/stomp/server/client/tx_store_test.go new file mode 100644 index 0000000..eca2774 --- /dev/null +++ b/backend/services/stomp/server/client/tx_store_test.go @@ -0,0 +1,81 @@ +package client + +import ( + "github.com/go-stomp/stomp/v3/frame" + . "gopkg.in/check.v1" +) + +type TxStoreSuite struct{} + +var _ = Suite(&TxStoreSuite{}) + +func (s *TxStoreSuite) TestDoubleBegin(c *C) { + txs := txStore{} + + err := txs.Begin("tx1") + c.Assert(err, IsNil) + + err = txs.Begin("tx1") + c.Assert(err, Equals, txAlreadyInProgress) +} + +func (s *TxStoreSuite) TestSuccessfulTx(c *C) { + txs := txStore{} + + err := txs.Begin("tx1") + c.Check(err, IsNil) + + err = txs.Begin("tx2") + c.Assert(err, IsNil) + + f1 := frame.New(frame.MESSAGE, + frame.Destination, "/queue/1") + + f2 := frame.New(frame.MESSAGE, + frame.Destination, "/queue/2") + + f3 := frame.New(frame.MESSAGE, + frame.Destination, "/queue/3") + + f4 := frame.New(frame.MESSAGE, + frame.Destination, "/queue/4") + + err = txs.Add("tx1", f1) + c.Assert(err, IsNil) + err = txs.Add("tx1", f2) + c.Assert(err, IsNil) + err = txs.Add("tx1", f3) + c.Assert(err, IsNil) + err = txs.Add("tx2", f4) + + var tx1 []*frame.Frame + + txs.Commit("tx1", func(f *frame.Frame) error { + tx1 = append(tx1, f) + return nil + }) + c.Check(err, IsNil) + + var tx2 []*frame.Frame + + err = txs.Commit("tx2", func(f *frame.Frame) error { + tx2 = append(tx2, f) + return nil + }) + c.Check(err, IsNil) + + c.Check(len(tx1), Equals, 3) + c.Check(tx1[0], Equals, f1) + c.Check(tx1[1], Equals, f2) + c.Check(tx1[2], Equals, f3) + + c.Check(len(tx2), Equals, 1) + c.Check(tx2[0], Equals, f4) + + // already committed, so should cause an error + err = txs.Commit("tx1", func(f *frame.Frame) error { + c.Fatal("should not be called") + return nil + }) + c.Check(err, Equals, txUnknown) +} diff --git a/backend/services/stomp/server/client/util.go b/backend/services/stomp/server/client/util.go new file mode 100644 index 0000000..683a9da --- /dev/null +++ b/backend/services/stomp/server/client/util.go @@ -0,0 +1,21 @@ +package client + +import ( + "time" +) + +// Convert a time.Duration to milliseconds in an integer. +// Returns the duration in milliseconds, or max if the +// duration is greater than max milliseconds. +func asMilliseconds(d time.Duration, max int) int { + if max < 0 { + max = 0 + } + max64 := int64(max) + msec64 := int64(d / time.Millisecond) + if msec64 > max64 { + msec64 = max64 + } + msec := int(msec64) + return msec +} diff --git a/backend/services/stomp/server/client/util_test.go b/backend/services/stomp/server/client/util_test.go new file mode 100644 index 0000000..c577f61 --- /dev/null +++ b/backend/services/stomp/server/client/util_test.go @@ -0,0 +1,23 @@ +package client + +import ( + . "gopkg.in/check.v1" + "math" + "time" +) + +type UtilSuite struct{} + +var _ = Suite(&UtilSuite{}) + +func (s *UtilSuite) TestAsMilliseconds(c *C) { + d := time.Duration(30) * time.Millisecond + c.Check(asMilliseconds(d, math.MaxInt32), Equals, 30) + + // approximately one year + d = time.Duration(365) * time.Duration(24) * time.Hour + c.Check(asMilliseconds(d, math.MaxInt32), Equals, math.MaxInt32) + + d = time.Duration(365) * time.Duration(24) * time.Hour + c.Check(asMilliseconds(d, maxHeartBeat), Equals, maxHeartBeat) +} diff --git a/backend/services/stomp/server/processor.go b/backend/services/stomp/server/processor.go new file mode 100644 index 0000000..291845a --- /dev/null +++ b/backend/services/stomp/server/processor.go @@ -0,0 +1,158 @@ +package server + +import ( + "net" + "strings" + "time" + + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/frame" + "github.com/go-stomp/stomp/v3/server/client" + "github.com/go-stomp/stomp/v3/server/queue" + "github.com/go-stomp/stomp/v3/server/topic" +) + +type requestProcessor struct { + server *Server + ch chan client.Request + tm *topic.Manager + qm *queue.Manager + stop bool // has stop been requested +} + +func newRequestProcessor(server *Server) *requestProcessor { + proc := &requestProcessor{ + server: server, + ch: make(chan client.Request, 128), + tm: topic.NewManager(), + } + + if server.QueueStorage == nil { + proc.qm = queue.NewManager(queue.NewMemoryQueueStorage()) + } else { + proc.qm = queue.NewManager(server.QueueStorage) + } + + return proc +} + +func (proc *requestProcessor) Serve(l net.Listener) error { + go proc.Listen(l) + + for { + r := <-proc.ch + switch r.Op { + case client.SubscribeOp: + if isQueueDestination(r.Sub.Destination()) { + queue := proc.qm.Find(r.Sub.Destination()) + // todo error handling + queue.Subscribe(r.Sub) + } else { + topic := proc.tm.Find(r.Sub.Destination()) + topic.Subscribe(r.Sub) + } + + case client.UnsubscribeOp: + if isQueueDestination(r.Sub.Destination()) { + queue := proc.qm.Find(r.Sub.Destination()) + // todo error handling + queue.Unsubscribe(r.Sub) + } else { + topic := proc.tm.Find(r.Sub.Destination()) + topic.Unsubscribe(r.Sub) + } + + case client.EnqueueOp: + destination, ok := r.Frame.Header.Contains(frame.Destination) + if !ok { + // should not happen, already checked in lower layer + panic("missing destination") + } + + if isQueueDestination(destination) { + queue := proc.qm.Find(destination) + queue.Enqueue(r.Frame) + } else { + topic := proc.tm.Find(destination) + topic.Enqueue(r.Frame) + } + + case client.RequeueOp: + destination, ok := r.Frame.Header.Contains(frame.Destination) + if !ok { + // should not happen, already checked in lower layer + panic("missing destination") + } + + // only requeue to queues, should never happen for topics + if isQueueDestination(destination) { + queue := proc.qm.Find(destination) + queue.Requeue(r.Frame) + } + } + } + // this is no longer required for go 1.1 + panic("not reached") +} + +func isQueueDestination(dest string) bool { + return strings.HasPrefix(dest, QueuePrefix) +} + +func (proc *requestProcessor) Listen(l net.Listener) { + config := newConfig(proc.server) + timeout := time.Duration(0) // how long to sleep on accept failure + for { + rw, err := l.Accept() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Temporary() { + if timeout == 0 { + timeout = 5 * time.Millisecond + } else { + timeout *= 2 + } + if max := 5 * time.Second; timeout > max { + timeout = max + } + proc.server.Log.Infof("stomp: Accept error: %v; retrying in %v", err, timeout) + time.Sleep(timeout) + continue + } + return + } + timeout = 0 + // TODO: need to pass Server to connection so it has access to + // configuration parameters. + _ = client.NewConn(config, rw, proc.ch) + } + // This is no longer required for go 1.1 + panic("not reached") +} + +type config struct { + server *Server +} + +func newConfig(s *Server) *config { + return &config{server: s} +} + +func (c *config) HeartBeat() time.Duration { + if c.server.HeartBeat == time.Duration(0) { + return DefaultHeartBeat + } + return c.server.HeartBeat +} + +func (c *config) Authenticate(login, passcode string) bool { + if c.server.Authenticator != nil { + return c.server.Authenticator.Authenticate(login, passcode) + } + + // no authentication defined + return true +} + +func (c *config) Logger() stomp.Logger { + return c.server.Log +} diff --git a/backend/services/stomp/server/queue/manager.go b/backend/services/stomp/server/queue/manager.go new file mode 100644 index 0000000..3e499df --- /dev/null +++ b/backend/services/stomp/server/queue/manager.go @@ -0,0 +1,23 @@ +package queue + +// Queue manager. +type Manager struct { + qstore Storage // handles queue storage + queues map[string]*Queue +} + +// Create a queue manager with the specified queue storage mechanism +func NewManager(qstore Storage) *Manager { + qm := &Manager{qstore: qstore, queues: make(map[string]*Queue)} + return qm +} + +// Finds the queue for the given destination, and creates it if necessary. +func (qm *Manager) Find(destination string) *Queue { + q, ok := qm.queues[destination] + if !ok { + q = newQueue(destination, qm.qstore) + qm.queues[destination] = q + } + return q +} diff --git a/backend/services/stomp/server/queue/manager_test.go b/backend/services/stomp/server/queue/manager_test.go new file mode 100644 index 0000000..d74bd54 --- /dev/null +++ b/backend/services/stomp/server/queue/manager_test.go @@ -0,0 +1,21 @@ +package queue + +import ( + . "gopkg.in/check.v1" +) + +type ManagerSuite struct{} + +var _ = Suite(&ManagerSuite{}) + +func (s *ManagerSuite) TestManager(c *C) { + mgr := NewManager(NewMemoryQueueStorage()) + + q1 := mgr.Find("/queue/1") + c.Assert(q1, NotNil) + + q2 := mgr.Find("/queue/2") + c.Assert(q2, NotNil) + + c.Assert(mgr.Find("/queue/1"), Equals, q1) +} diff --git a/backend/services/stomp/server/queue/memory_queue.go b/backend/services/stomp/server/queue/memory_queue.go new file mode 100644 index 0000000..ebd204e --- /dev/null +++ b/backend/services/stomp/server/queue/memory_queue.go @@ -0,0 +1,70 @@ +package queue + +import ( + "container/list" + + "github.com/go-stomp/stomp/v3/frame" +) + +// In-memory implementation of the QueueStorage interface. +type MemoryQueueStorage struct { + lists map[string]*list.List +} + +func NewMemoryQueueStorage() Storage { + m := &MemoryQueueStorage{lists: make(map[string]*list.List)} + return m +} + +func (m *MemoryQueueStorage) Enqueue(queue string, frame *frame.Frame) error { + l, ok := m.lists[queue] + if !ok { + l = list.New() + m.lists[queue] = l + } + l.PushBack(frame) + + return nil +} + +// Pushes a frame to the head of the queue. Sets +// the "message-id" header of the frame if it is not +// already set. +func (m *MemoryQueueStorage) Requeue(queue string, frame *frame.Frame) error { + l, ok := m.lists[queue] + if !ok { + l = list.New() + m.lists[queue] = l + } + l.PushFront(frame) + + return nil +} + +// Removes a frame from the head of the queue. +// Returns nil if no frame is available. +func (m *MemoryQueueStorage) Dequeue(queue string) (*frame.Frame, error) { + l, ok := m.lists[queue] + if !ok { + return nil, nil + } + + element := l.Front() + if element == nil { + return nil, nil + } + + return l.Remove(element).(*frame.Frame), nil +} + +// Called at server startup. Allows the queue storage +// to perform any initialization. +func (m *MemoryQueueStorage) Start() { + m.lists = make(map[string]*list.List) +} + +// Called prior to server shutdown. Allows the queue storage +// to perform any cleanup. +func (m *MemoryQueueStorage) Stop() { + m.lists = nil +} diff --git a/backend/services/stomp/server/queue/memory_queue_test.go b/backend/services/stomp/server/queue/memory_queue_test.go new file mode 100644 index 0000000..7037291 --- /dev/null +++ b/backend/services/stomp/server/queue/memory_queue_test.go @@ -0,0 +1,64 @@ +package queue + +import ( + "github.com/go-stomp/stomp/v3/frame" + . "gopkg.in/check.v1" +) + +type MemoryQueueSuite struct{} + +var _ = Suite(&MemoryQueueSuite{}) + +func (s *MemoryQueueSuite) Test1(c *C) { + mq := NewMemoryQueueStorage() + mq.Start() + + f1 := frame.New(frame.MESSAGE, + frame.Destination, "/queue/test", + frame.MessageId, "msg-001", + frame.Subscription, "1") + + err := mq.Enqueue("/queue/test", f1) + c.Assert(err, IsNil) + + f2 := frame.New(frame.MESSAGE, + frame.Destination, "/queue/test", + frame.MessageId, "msg-002", + frame.Subscription, "1") + + err = mq.Enqueue("/queue/test", f2) + c.Assert(err, IsNil) + + f3 := frame.New(frame.MESSAGE, + frame.Destination, "/queue/test2", + frame.MessageId, "msg-003", + frame.Subscription, "2") + + err = mq.Enqueue("/queue/test2", f3) + c.Assert(err, IsNil) + + // attempt to dequeue from a different queue + f, err := mq.Dequeue("/queue/other-queue") + c.Check(err, IsNil) + c.Assert(f, IsNil) + + f, err = mq.Dequeue("/queue/test2") + c.Check(err, IsNil) + c.Assert(f, Equals, f3) + + f, err = mq.Dequeue("/queue/test") + c.Check(err, IsNil) + c.Assert(f, Equals, f1) + + f, err = mq.Dequeue("/queue/test") + c.Check(err, IsNil) + c.Assert(f, Equals, f2) + + f, err = mq.Dequeue("/queue/test") + c.Check(err, IsNil) + c.Assert(f, IsNil) + + f, err = mq.Dequeue("/queue/test2") + c.Check(err, IsNil) + c.Assert(f, IsNil) +} diff --git a/backend/services/stomp/server/queue/queue.go b/backend/services/stomp/server/queue/queue.go new file mode 100644 index 0000000..6b7beab --- /dev/null +++ b/backend/services/stomp/server/queue/queue.go @@ -0,0 +1,86 @@ +/* +Package queue provides implementations of server-side queues. +*/ +package queue + +import ( + "github.com/go-stomp/stomp/v3/frame" + "github.com/go-stomp/stomp/v3/server/client" +) + +// Queue for storing message frames. +type Queue struct { + destination string + qstore Storage + subs *client.SubscriptionList +} + +// Create a new queue -- called from the queue manager only. +func newQueue(destination string, qstore Storage) *Queue { + return &Queue{ + destination: destination, + qstore: qstore, + subs: client.NewSubscriptionList(), + } +} + +// Add a subscription to a queue. The subscription is removed +// whenever a frame is sent to the subscription and needs to +// be re-added when the subscription decides that the message +// has been received by the client. +func (q *Queue) Subscribe(sub *client.Subscription) error { + // see if there is a frame available for this subscription + f, err := q.qstore.Dequeue(sub.Destination()) + if err != nil { + return err + } + if f == nil { + // no frame available, so add to the subscription list + q.subs.Add(sub) + } else { + // a frame is available, so send straight away without + // adding the subscription to the list + sub.SendQueueFrame(f) + } + return nil +} + +// Unsubscribe a subscription. +func (q *Queue) Unsubscribe(sub *client.Subscription) { + q.subs.Remove(sub) +} + +// Send a message to the queue. If a subscription is available +// to receive the message, it is sent to the subscription without +// making it to the queue. Otherwise, the message is queued until +// a message is available. +func (q *Queue) Enqueue(f *frame.Frame) error { + // find a subscription ready to receive the frame + sub := q.subs.Get() + if sub == nil { + // no subscription available, add to the queue + return q.qstore.Enqueue(q.destination, f) + } else { + // subscription is available, send it now without adding to queue + sub.SendQueueFrame(f) + } + return nil +} + +// Send a message to the front of the queue, probably because it +// failed to be sent to a client. If a subscription is available +// to receive the message, it is sent to the subscription without +// making it to the queue. Otherwise, the message is queued until +// a message is available. +func (q *Queue) Requeue(f *frame.Frame) error { + // find a subscription ready to receive the frame + sub := q.subs.Get() + if sub == nil { + // no subscription available, add to the queue + return q.qstore.Requeue(q.destination, f) + } else { + // subscription is available, send it now without adding to queue + sub.SendQueueFrame(f) + } + return nil +} diff --git a/backend/services/stomp/server/queue/queue_test.go b/backend/services/stomp/server/queue/queue_test.go new file mode 100644 index 0000000..8517172 --- /dev/null +++ b/backend/services/stomp/server/queue/queue_test.go @@ -0,0 +1,12 @@ +package queue + +import ( + "gopkg.in/check.v1" + "testing" +) + +// Runs all gocheck tests in this package. +// See other *_test.go files for gocheck tests. +func TestQueue(t *testing.T) { + check.TestingT(t) +} diff --git a/backend/services/stomp/server/queue/storage.go b/backend/services/stomp/server/queue/storage.go new file mode 100644 index 0000000..9d6eead --- /dev/null +++ b/backend/services/stomp/server/queue/storage.go @@ -0,0 +1,34 @@ +package queue + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// Interface for queue storage. The intent is that +// different queue storage implementations can be +// used, depending on preference. Queue storage +// mechanisms could include in-memory, and various +// persistent storage mechanisms (eg file system, DB, etc) +type Storage interface { + // Pushes a MESSAGE frame to the end of the queue. Sets + // the "message-id" header of the frame before adding to + // the queue. + Enqueue(queue string, frame *frame.Frame) error + + // Pushes a MESSAGE frame to the head of the queue. Sets + // the "message-id" header of the frame if it is not + // already set. + Requeue(queue string, frame *frame.Frame) error + + // Removes a frame from the head of the queue. + // Returns nil if no frame is available. + Dequeue(queue string) (*frame.Frame, error) + + // Called at server startup. Allows the queue storage + // to perform any initialization. + Start() + + // Called prior to server shutdown. Allows the queue storage + // to perform any cleanup. + Stop() +} diff --git a/backend/services/stomp/server/queue_storage.go b/backend/services/stomp/server/queue_storage.go new file mode 100644 index 0000000..d922597 --- /dev/null +++ b/backend/services/stomp/server/queue_storage.go @@ -0,0 +1,30 @@ +package server + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// QueueStorage is an interface that abstracts the queue storage mechanism. +// The intent is that different queue storage implementations can be +// used, depending on preference. Queue storage mechanisms could include +// in-memory, and various persistent storage mechanisms (eg file system, DB, etc). +type QueueStorage interface { + // Enqueue adds a MESSAGE frame to the end of the queue. + Enqueue(queue string, frame *frame.Frame) error + + // Requeue adds a MESSAGE frame to the head of the queue. + // This will happen if a client fails to acknowledge receipt. + Requeue(queue string, frame *frame.Frame) error + + // Dequeue removes a frame from the head of the queue. + // Returns nil if no frame is available. + Dequeue(queue string) (*frame.Frame, error) + + // Start is called at server startup. Allows the queue storage + // to perform any initialization. + Start() + + // Stop is called prior to server shutdown. Allows the queue storage + // to perform any cleanup, such as flushing to disk. + Stop() +} diff --git a/backend/services/stomp/server/server.go b/backend/services/stomp/server/server.go new file mode 100644 index 0000000..dacde5a --- /dev/null +++ b/backend/services/stomp/server/server.go @@ -0,0 +1,89 @@ +/* +Package server contains a simple STOMP server implementation. +*/ +package server + +import ( + "net" + "time" + + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/internal/log" +) + +// The STOMP server has the concept of queues and topics. A message +// sent to a queue destination will be transmitted to the next available +// client that has subscribed. A message sent to a topic will be +// transmitted to all subscribers that are currently subscribed to the +// topic. +// +// Destinations that start with this prefix are considered to be queues. +// Destinations that do not start with this prefix are considered to be topics. +const QueuePrefix = "/queue" + +// Default server parameters. +const ( + // Default address for listening for connections. + DefaultAddr = ":61613" + + // Default read timeout for heart-beat. + // Override by setting Server.HeartBeat. + DefaultHeartBeat = time.Minute +) + +// Interface for authenticating STOMP clients. +type Authenticator interface { + // Authenticate based on the given login and passcode, either of which might be nil. + // Returns true if authentication is successful, false otherwise. + Authenticate(login, passcode string) bool +} + +// A Server defines parameters for running a STOMP server. +type Server struct { + Addr string // TCP address to listen on, DefaultAddr if empty + Authenticator Authenticator // Authenticates login/passcodes. If nil no authentication is performed + QueueStorage QueueStorage // Implementation of queue storage. If nil, in-memory queues are used. + HeartBeat time.Duration // Preferred value for heart-beat read/write timeout, if zero, then DefaultHeartBeat. + Log stomp.Logger +} + +// ListenAndServe listens on the TCP network address addr and then calls Serve. +func ListenAndServe(addr string) error { + s := &Server{Addr: addr} + return s.ListenAndServe() +} + +// Serve accepts incoming TCP connections on the listener l, creating a new +// STOMP service thread for each connection. +func Serve(l net.Listener) error { + s := &Server{} + return s.Serve(l) +} + +// ListenAndServe listens on the TCP network address s.Addr and +// then calls Serve to handle requests on the incoming connections. +// If s.Addr is blank, then DefaultAddr is used. +func (s *Server) ListenAndServe() error { + addr := s.Addr + if addr == "" { + addr = DefaultAddr + } + l, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + return s.Serve(l) +} + +// Serve accepts incoming connections on the Listener l, creating a new +// service thread for each connection. The service threads read +// requests and then process each request. +func (s *Server) Serve(l net.Listener) error { + if s.Log == nil { + s.Log = log.StdLogger{} + } + + proc := newRequestProcessor(s) + return proc.Serve(l) +} diff --git a/backend/services/stomp/server/server_test.go b/backend/services/stomp/server/server_test.go new file mode 100644 index 0000000..9ab099e --- /dev/null +++ b/backend/services/stomp/server/server_test.go @@ -0,0 +1,170 @@ +package server + +import ( + "fmt" + "net" + "runtime" + "testing" + "time" + + "github.com/go-stomp/stomp/v3" + . "gopkg.in/check.v1" +) + +func TestServer(t *testing.T) { + TestingT(t) +} + +type ServerSuite struct{} + +var _ = Suite(&ServerSuite{}) + +func (s *ServerSuite) SetUpSuite(c *C) { + runtime.GOMAXPROCS(runtime.NumCPU()) +} + +func (s *ServerSuite) TearDownSuite(c *C) { + runtime.GOMAXPROCS(1) +} + +func (s *ServerSuite) TestConnectAndDisconnect(c *C) { + addr := ":59091" + l, err := net.Listen("tcp", addr) + c.Assert(err, IsNil) + defer func() { l.Close() }() + go Serve(l) + + conn, err := net.Dial("tcp", "127.0.0.1"+addr) + c.Assert(err, IsNil) + + client, err := stomp.Connect(conn) + c.Assert(err, IsNil) + + err = client.Disconnect() + c.Assert(err, IsNil) + + conn.Close() +} + + +func (s *ServerSuite) TestHeartBeatingTolerance(c *C) { + // Heart beat should not close connection exactly after not receiving message after cx + // it should add a pretty decent amount of time to counter network delay of other timing issues + l, err := net.Listen("tcp", `127.0.0.1:0`) + c.Assert(err, IsNil) + defer func() { l.Close() }() + serv := Server{ + Addr: l.Addr().String(), + Authenticator: nil, + QueueStorage: nil, + HeartBeat: 5 * time.Millisecond, + } + go serv.Serve(l) + + conn, err := net.Dial("tcp", l.Addr().String()) + c.Assert(err, IsNil) + defer conn.Close() + + client, err := stomp.Connect(conn, + stomp.ConnOpt.HeartBeat(5 * time.Millisecond, 5 * time.Millisecond), + ) + c.Assert(err, IsNil) + defer client.Disconnect() + + time.Sleep(serv.HeartBeat * 20) // let it go for some time to allow client and server to exchange some heart beat + + // Ensure the server has not closed his readChannel + err = client.Send("/topic/whatever", "text/plain", []byte("hello")) + c.Assert(err, IsNil) +} + +func (s *ServerSuite) TestSendToQueuesAndTopics(c *C) { + ch := make(chan bool, 2) + println("number cpus:", runtime.NumCPU()) + + addr := ":59092" + + l, err := net.Listen("tcp", addr) + c.Assert(err, IsNil) + defer func() { l.Close() }() + go Serve(l) + + // channel to communicate that the go routine has started + started := make(chan bool) + + count := 100 + go runReceiver(c, ch, count, "/topic/test-1", addr, started) + <-started + go runReceiver(c, ch, count, "/topic/test-1", addr, started) + <-started + go runReceiver(c, ch, count, "/topic/test-2", addr, started) + <-started + go runReceiver(c, ch, count, "/topic/test-2", addr, started) + <-started + go runReceiver(c, ch, count, "/topic/test-1", addr, started) + <-started + go runReceiver(c, ch, count, "/queue/test-1", addr, started) + <-started + go runSender(c, ch, count, "/queue/test-1", addr, started) + <-started + go runSender(c, ch, count, "/queue/test-2", addr, started) + <-started + go runReceiver(c, ch, count, "/queue/test-2", addr, started) + <-started + go runSender(c, ch, count, "/topic/test-1", addr, started) + <-started + go runReceiver(c, ch, count, "/queue/test-3", addr, started) + <-started + go runSender(c, ch, count, "/queue/test-3", addr, started) + <-started + go runSender(c, ch, count, "/queue/test-4", addr, started) + <-started + go runSender(c, ch, count, "/topic/test-2", addr, started) + <-started + go runReceiver(c, ch, count, "/queue/test-4", addr, started) + <-started + + for i := 0; i < 15; i++ { + <-ch + } +} + +func runSender(c *C, ch chan bool, count int, destination, addr string, started chan bool) { + conn, err := net.Dial("tcp", "127.0.0.1"+addr) + c.Assert(err, IsNil) + + client, err := stomp.Connect(conn) + c.Assert(err, IsNil) + + started <- true + + for i := 0; i < count; i++ { + client.Send(destination, "text/plain", + []byte(fmt.Sprintf("%s test message %d", destination, i))) + //println("sent", i) + } + + ch <- true +} + +func runReceiver(c *C, ch chan bool, count int, destination, addr string, started chan bool) { + conn, err := net.Dial("tcp", "127.0.0.1"+addr) + c.Assert(err, IsNil) + + client, err := stomp.Connect(conn) + c.Assert(err, IsNil) + + sub, err := client.Subscribe(destination, stomp.AckAuto) + c.Assert(err, IsNil) + c.Assert(sub, NotNil) + + started <- true + + for i := 0; i < count; i++ { + msg := <-sub.C + expectedText := fmt.Sprintf("%s test message %d", destination, i) + c.Assert(msg.Body, DeepEquals, []byte(expectedText)) + //println("received", i) + } + ch <- true +} diff --git a/backend/services/stomp/server/topic/manager.go b/backend/services/stomp/server/topic/manager.go new file mode 100644 index 0000000..0b9611d --- /dev/null +++ b/backend/services/stomp/server/topic/manager.go @@ -0,0 +1,24 @@ +package topic + +// Manager is a struct responsible for finding topics. Topics are +// not created by the package user, rather they are created on demand +// by the topic manager. +type Manager struct { + topics map[string]*Topic +} + +// NewManager creates a new topic manager. +func NewManager() *Manager { + tm := &Manager{topics: make(map[string]*Topic)} + return tm +} + +// Finds the topic for the given destination, and creates it if necessary. +func (tm *Manager) Find(destination string) *Topic { + t, ok := tm.topics[destination] + if !ok { + t = newTopic(destination) + tm.topics[destination] = t + } + return t +} diff --git a/backend/services/stomp/server/topic/manager_test.go b/backend/services/stomp/server/topic/manager_test.go new file mode 100644 index 0000000..2fa76df --- /dev/null +++ b/backend/services/stomp/server/topic/manager_test.go @@ -0,0 +1,21 @@ +package topic + +import ( + . "gopkg.in/check.v1" +) + +type ManagerSuite struct{} + +var _ = Suite(&ManagerSuite{}) + +func (s *ManagerSuite) TestManager(c *C) { + mgr := NewManager() + + t1 := mgr.Find("topic1") + c.Assert(t1, NotNil) + + t2 := mgr.Find("topic2") + c.Assert(t2, NotNil) + + c.Assert(mgr.Find("topic1"), Equals, t1) +} diff --git a/backend/services/stomp/server/topic/subscription.go b/backend/services/stomp/server/topic/subscription.go new file mode 100644 index 0000000..956c347 --- /dev/null +++ b/backend/services/stomp/server/topic/subscription.go @@ -0,0 +1,11 @@ +package topic + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// Subscription is the interface that wraps a subscriber to a topic. +type Subscription interface { + // Send a message frame to the topic subscriber. + SendTopicFrame(f *frame.Frame) +} diff --git a/backend/services/stomp/server/topic/testing_test.go b/backend/services/stomp/server/topic/testing_test.go new file mode 100644 index 0000000..21ec5c6 --- /dev/null +++ b/backend/services/stomp/server/topic/testing_test.go @@ -0,0 +1,12 @@ +package topic + +import ( + "gopkg.in/check.v1" + "testing" +) + +// Runs all gocheck tests in this package. +// See other *_test.go files for gocheck tests. +func Test(t *testing.T) { + check.TestingT(t) +} diff --git a/backend/services/stomp/server/topic/topic.go b/backend/services/stomp/server/topic/topic.go new file mode 100644 index 0000000..afe4c85 --- /dev/null +++ b/backend/services/stomp/server/topic/topic.go @@ -0,0 +1,73 @@ +/* +Package topic provides implementations of server-side topics. +*/ +package topic + +import ( + "container/list" + + "github.com/go-stomp/stomp/v3/frame" +) + +// A Topic is used for broadcasting to subscribed clients. +// In contrast to a queue, when a message is sent to a topic, +// that message is transmitted to all subscribed clients. +type Topic struct { + destination string + subs *list.List +} + +// Create a new topic -- called from the topic manager only. +func newTopic(destination string) *Topic { + return &Topic{ + destination: destination, + subs: list.New(), + } +} + +// Subscribe adds a subscription to a topic. Any message sent to the +// topic will be transmitted to the subscription's client until +// unsubscription occurs. +func (t *Topic) Subscribe(sub Subscription) { + t.subs.PushBack(sub) +} + +// Unsubscribe causes a subscription to be removed from the topic. +func (t *Topic) Unsubscribe(sub Subscription) { + for e := t.subs.Front(); e != nil; e = e.Next() { + if sub == e.Value.(Subscription) { + t.subs.Remove(e) + return + } + } +} + +// Enqueue send a message to the topic. All subscriptions receive a copy +// of the message. +func (t *Topic) Enqueue(f *frame.Frame) { + switch t.subs.Len() { + case 0: + // no subscription, so do nothing + + case 1: + // only one subscription, so can send the frame + // without copying + sub := t.subs.Front().Value.(Subscription) + sub.SendTopicFrame(f) + + default: + // more than one subscription, send clone for + // all subscriptions except the last, which can + // have the frame without copying + for e := t.subs.Front(); e != nil; e = e.Next() { + sub := e.Value.(Subscription) + if e.Next() == nil { + // the last in the list, send the frame + // without copying + sub.SendTopicFrame(f) + } else { + sub.SendTopicFrame(f.Clone()) + } + } + } +} diff --git a/backend/services/stomp/server/topic/topic_test.go b/backend/services/stomp/server/topic/topic_test.go new file mode 100644 index 0000000..863c0d1 --- /dev/null +++ b/backend/services/stomp/server/topic/topic_test.go @@ -0,0 +1,63 @@ +package topic + +import ( + "github.com/go-stomp/stomp/v3/frame" + . "gopkg.in/check.v1" +) + +type TopicSuite struct{} + +var _ = Suite(&TopicSuite{}) + +func (s *TopicSuite) TestTopicWithoutSubscription(c *C) { + topic := newTopic("destination") + + f := frame.New(frame.MESSAGE, + frame.Destination, "destination") + + topic.Enqueue(f) +} + +func (s *TopicSuite) TestTopicWithOneSubscription(c *C) { + sub := &fakeSubscription{} + + topic := newTopic("destination") + topic.Subscribe(sub) + + f := frame.New(frame.MESSAGE, + frame.Destination, "destination") + + topic.Enqueue(f) + + c.Assert(len(sub.Frames), Equals, 1) + c.Assert(sub.Frames[0], Equals, f) +} + +func (s *TopicSuite) TestTopicWithTwoSubscriptions(c *C) { + sub1 := &fakeSubscription{} + sub2 := &fakeSubscription{} + + topic := newTopic("destination") + topic.Subscribe(sub1) + topic.Subscribe(sub2) + + f := frame.New(frame.MESSAGE, + frame.Destination, "destination", + "xxx", "yyy") + + topic.Enqueue(f) + + c.Assert(len(sub1.Frames), Equals, 1) + c.Assert(len(sub2.Frames), Equals, 1) + c.Assert(sub1.Frames[0], Not(Equals), f) + c.Assert(sub2.Frames[0], Equals, f) +} + +type fakeSubscription struct { + // frames received by the subscription + Frames []*frame.Frame +} + +func (s *fakeSubscription) SendTopicFrame(f *frame.Frame) { + s.Frames = append(s.Frames, f) +} diff --git a/backend/services/stomp/stomp.go b/backend/services/stomp/stomp.go new file mode 100644 index 0000000..bbe0863 --- /dev/null +++ b/backend/services/stomp/stomp.go @@ -0,0 +1,26 @@ +/* +Package stomp provides operations that allow communication with a message broker that supports the STOMP protocol. +STOMP is the Streaming Text-Oriented Messaging Protocol. See http://stomp.github.com/ for more details. + +This package provides support for all STOMP protocol features in the STOMP protocol specifications, +versions 1.0, 1.1 and 1.2. These features including protocol negotiation, heart-beating, value encoding, +and graceful shutdown. + +Connecting to a STOMP server is achieved using the stomp.Dial function, or the stomp.Connect function. See +the examples section for a summary of how to use these functions. Both functions return a stomp.Conn object +for subsequent interaction with the STOMP server. + +Once a connection (stomp.Conn) is created, it can be used to send messages to the STOMP server, or create +subscriptions for receiving messages from the STOMP server. Transactions can be created to send multiple +messages and/ or acknowledge multiple received messages from the server in one, atomic transaction. The +examples section has examples of using subscriptions and transactions. + +The client program can instruct the stomp.Conn to gracefully disconnect from the STOMP server using the +Disconnect method. This will perform a graceful shutdown sequence as specified in the STOMP specification. + +Source code and other details for the project are available at GitHub: + + https://github.com/go-stomp/stomp + +*/ +package stomp diff --git a/backend/services/stomp/stomp_test.go b/backend/services/stomp/stomp_test.go new file mode 100644 index 0000000..f8feca0 --- /dev/null +++ b/backend/services/stomp/stomp_test.go @@ -0,0 +1,18 @@ +package stomp + +import ( + "testing" + + "gopkg.in/check.v1" +) + +// Runs all gocheck tests in this package. +// See other *_test.go files for gocheck tests. +func TestStomp(t *testing.T) { + check.Suite(&StompSuite{t}) + check.TestingT(t) +} + +type StompSuite struct { + t *testing.T +} diff --git a/backend/services/stomp/stompd/main.go b/backend/services/stomp/stompd/main.go new file mode 100644 index 0000000..8be250f --- /dev/null +++ b/backend/services/stomp/stompd/main.go @@ -0,0 +1,63 @@ +/* +A simple, stand-alone STOMP server. + +TODO: graceful shutdown + +TODO: UNIX daemon functionality + +TODO: Windows service functionality (if possible?) + +TODO: Logging options (syslog, windows event log) +*/ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + + "github.com/go-stomp/stomp/v3/server" +) + +// TODO: experimenting with ways to gracefully shutdown the server, +// at the moment it just dies ungracefully on SIGINT. + +/* + +func main() { + // create a channel for listening for termination signals + stopChannel := newStopChannel() + + for { + select { + case sig := <-stopChannel: + log.Println("received signal:", sig) + break + } + } + +} +*/ + +var listenAddr = flag.String("addr", ":61613", "Listen address") +var helpFlag = flag.Bool("help", false, "Show this help text") + +func main() { + flag.Parse() + if *helpFlag { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + flag.PrintDefaults() + os.Exit(1) + } + + l, err := net.Listen("tcp", *listenAddr) + if err != nil { + log.Fatalf("failed to listen: %s", err.Error()) + } + defer func() { l.Close() }() + + log.Println("listening on", l.Addr().Network(), l.Addr().String()) + server.Serve(l) +} diff --git a/backend/services/stomp/stompd/signals.go b/backend/services/stomp/stompd/signals.go new file mode 100644 index 0000000..1b1abd1 --- /dev/null +++ b/backend/services/stomp/stompd/signals.go @@ -0,0 +1,19 @@ +package main + +import ( + "os" + "os/signal" +) + +// newStopChannel creates a channel for receiving signals +// for stopping the program. Calls an os-dependent setupStopSignals +// function. +func newStopChannel() chan os.Signal { + c := make(chan os.Signal, 2) + signal.Notify(c, os.Interrupt) + + // os dependent between windows and unix + setupStopSignals(c) + + return c +} diff --git a/backend/services/stomp/stompd/signals_unix.go b/backend/services/stomp/stompd/signals_unix.go new file mode 100644 index 0000000..4c060ca --- /dev/null +++ b/backend/services/stomp/stompd/signals_unix.go @@ -0,0 +1,17 @@ +package main + +import ( + "os" + "os/signal" + "syscall" +) + +// setupStopSignals sets up UNIX-specific signals for terminating +// the program +func setupStopSignals(signalChannel chan os.Signal) { + // TODO: not sure whether SIGHUP should be used here, only if not in + // daemon mode + signal.Notify(signalChannel, syscall.SIGHUP) + + signal.Notify(signalChannel, syscall.SIGTERM) +} diff --git a/backend/services/stomp/stompd/signals_windows.go b/backend/services/stomp/stompd/signals_windows.go new file mode 100644 index 0000000..1454a92 --- /dev/null +++ b/backend/services/stomp/stompd/signals_windows.go @@ -0,0 +1,13 @@ +package main + +import ( + "os" +) + +func signals(signalChannel chan os.Signal) { + // Windows has no other signals other than os.Interrupt + + // TODO: What might be good here is to simulate a signal + // if running as a Windows service and the stop request is + // received. Not sure how to do this though. +} diff --git a/backend/services/stomp/stompd/stompd b/backend/services/stomp/stompd/stompd new file mode 100755 index 0000000..6a48d26 Binary files /dev/null and b/backend/services/stomp/stompd/stompd differ diff --git a/backend/services/stomp/subscribe_options.go b/backend/services/stomp/subscribe_options.go new file mode 100644 index 0000000..e5a5b18 --- /dev/null +++ b/backend/services/stomp/subscribe_options.go @@ -0,0 +1,42 @@ +package stomp + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// SubscribeOpt contains options for for the Conn.Subscribe function. +var SubscribeOpt struct { + // Id provides the opportunity to specify the value of the "id" header + // entry in the STOMP SUBSCRIBE frame. + // + // If the client program does specify the value for "id", + // it is responsible for choosing a unique value. + Id func(id string) func(*frame.Frame) error + + // Header provides the opportunity to include custom header entries + // in the SUBSCRIBE frame that the client sends to the server. + Header func(key, value string) func(*frame.Frame) error +} + +func init() { + SubscribeOpt.Id = func(id string) func(*frame.Frame) error { + return func(f *frame.Frame) error { + if f.Command != frame.SUBSCRIBE { + return ErrInvalidCommand + } + f.Header.Set(frame.Id, id) + return nil + } + } + + SubscribeOpt.Header = func(key, value string) func(*frame.Frame) error { + return func(f *frame.Frame) error { + if f.Command != frame.SUBSCRIBE && + f.Command != frame.UNSUBSCRIBE { + return ErrInvalidCommand + } + f.Header.Add(key, value) + return nil + } + } +} diff --git a/backend/services/stomp/subscription.go b/backend/services/stomp/subscription.go new file mode 100644 index 0000000..6aeaaf7 --- /dev/null +++ b/backend/services/stomp/subscription.go @@ -0,0 +1,183 @@ +package stomp + +import ( + "fmt" + "sync" + "sync/atomic" + + "github.com/go-stomp/stomp/v3/frame" +) + +const ( + subStateActive = 0 + subStateClosing = 1 + subStateClosed = 2 +) + +// The Subscription type represents a client subscription to +// a destination. The subscription is created by calling Conn.Subscribe. +// +// Once a client has subscribed, it can receive messages from the C channel. +type Subscription struct { + C chan *Message + id string + destination string + conn *Conn + ackMode AckMode + state int32 + closeMutex *sync.Mutex + closeCond *sync.Cond +} + +// BUG(jpj): If the client does not read messages from the Subscription.C +// channel quickly enough, the client will stop reading messages from the +// server. + +// Identification for this subscription. Unique among +// all subscriptions for the same Client. +func (s *Subscription) Id() string { + return s.id +} + +// Destination for which the subscription applies. +func (s *Subscription) Destination() string { + return s.destination +} + +// AckMode returns the Acknowledgement mode specified when the +// subscription was created. +func (s *Subscription) AckMode() AckMode { + return s.ackMode +} + +// Active returns whether the subscription is still active. +// Returns false if the subscription has been unsubscribed. +func (s *Subscription) Active() bool { + return atomic.LoadInt32(&s.state) == subStateActive +} + +// Unsubscribes and closes the channel C. +func (s *Subscription) Unsubscribe(opts ...func(*frame.Frame) error) error { + // transition to the "closing" state + if !atomic.CompareAndSwapInt32(&s.state, subStateActive, subStateClosing) { + return ErrCompletedSubscription + } + + f := frame.New(frame.UNSUBSCRIBE, frame.Id, s.id) + + for _, opt := range opts { + if opt == nil { + return ErrNilOption + } + err := opt(f) + if err != nil { + return err + } + } + + s.conn.sendFrame(f) + + // UNSUBSCRIBE is a bit weird in that it is tagged with a "receipt" header + // on the I/O goroutine, so the above call to sendFrame() will not wait + // for the resulting RECEIPT. + // + // We don't want to interfere with `s.C` since we might be "stealing" + // MESSAGEs or ERRORs from another goroutine, so use a sync.Cond to + // wait for the terminal state transition instead. + s.closeMutex.Lock() + for atomic.LoadInt32(&s.state) != subStateClosed { + s.closeCond.Wait() + } + s.closeMutex.Unlock() + return nil +} + +// Read a message from the subscription. This is a convenience +// method: many callers will prefer to read from the channel C +// directly. +func (s *Subscription) Read() (*Message, error) { + if !s.Active() { + return nil, ErrCompletedSubscription + } + msg, ok := <-s.C + if !ok { + return nil, ErrCompletedSubscription + } + if msg.Err != nil { + return nil, msg.Err + } + return msg, nil +} + +func (s *Subscription) closeChannel(msg *Message) { + if msg != nil { + s.C <- msg + } + atomic.StoreInt32(&s.state, subStateClosed) + close(s.C) + s.closeCond.Broadcast() +} + +func (s *Subscription) readLoop(ch chan *frame.Frame) { + for { + f, ok := <-ch + if !ok { + state := atomic.LoadInt32(&s.state) + if state == subStateActive || state == subStateClosing { + msg := &Message{ + Err: &Error{ + Message: fmt.Sprintf("Subscription %s: %s: channel read failed", s.id, s.destination), + }, + } + s.closeChannel(msg) + } + return + } + + if f.Command == frame.MESSAGE { + destination := f.Header.Get(frame.Destination) + contentType := f.Header.Get(frame.ContentType) + msg := &Message{ + Destination: destination, + ContentType: contentType, + Conn: s.conn, + Subscription: s, + Header: f.Header, + Body: f.Body, + } + s.C <- msg + } else if f.Command == frame.ERROR { + state := atomic.LoadInt32(&s.state) + if state == subStateActive || state == subStateClosing { + message, _ := f.Header.Contains(frame.Message) + text := fmt.Sprintf("Subscription %s: %s: ERROR message:%s", + s.id, + s.destination, + message) + s.conn.log.Info(text) + contentType := f.Header.Get(frame.ContentType) + msg := &Message{ + Err: &Error{ + Message: f.Header.Get(frame.Message), + Frame: f, + }, + ContentType: contentType, + Conn: s.conn, + Subscription: s, + Header: f.Header, + Body: f.Body, + } + s.closeChannel(msg) + } + return + } else if f.Command == frame.RECEIPT { + state := atomic.LoadInt32(&s.state) + if state == subStateActive || state == subStateClosing { + s.closeChannel(nil) + } + return + } else { + s.conn.log.Infof("Subscription %s: %s: unsupported frame type: %+v", s.id, s.destination, f) + } + } +} diff --git a/backend/services/stomp/testutil/fake_conn.go b/backend/services/stomp/testutil/fake_conn.go new file mode 100644 index 0000000..c03659d --- /dev/null +++ b/backend/services/stomp/testutil/fake_conn.go @@ -0,0 +1,113 @@ +package testutil + +import ( + "errors" + . "gopkg.in/check.v1" + "io" + "net" + "time" +) + +type FakeAddr struct { + Value string +} + +func (addr *FakeAddr) Network() string { + return "fake" +} + +func (addr *FakeAddr) String() string { + return addr.Value +} + +// FakeConn is a fake connection used for testing. It implements +// the net.Conn interface and is useful for simulating I/O between +// STOMP clients and a STOMP server. +type FakeConn struct { + C *C + writer io.WriteCloser + reader io.ReadCloser + localAddr net.Addr + remoteAddr net.Addr +} + +var ( + ErrClosing = errors.New("use of closed network connection") +) + +// NewFakeConn returns a pair of fake connections suitable for +// testing. +func NewFakeConn(c *C) (client *FakeConn, server *FakeConn) { + clientReader, serverWriter := io.Pipe() + serverReader, clientWriter := io.Pipe() + clientAddr := &FakeAddr{Value: "the-client:123"} + serverAddr := &FakeAddr{Value: "the-server:456"} + + clientConn := &FakeConn{ + C: c, + reader: clientReader, + writer: clientWriter, + localAddr: clientAddr, + remoteAddr: serverAddr, + } + + serverConn := &FakeConn{ + C: c, + reader: serverReader, + writer: serverWriter, + localAddr: serverAddr, + remoteAddr: clientAddr, + } + + return clientConn, serverConn +} + +func (fc *FakeConn) Read(p []byte) (n int, err error) { + n, err = fc.reader.Read(p) + return +} + +func (fc *FakeConn) Write(p []byte) (n int, err error) { + return fc.writer.Write(p) +} + +func (fc *FakeConn) Close() error { + err1 := fc.reader.Close() + err2 := fc.writer.Close() + + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + return nil +} + +func (fc *FakeConn) LocalAddr() net.Addr { + return fc.localAddr +} + +func (fc *FakeConn) RemoteAddr() net.Addr { + return fc.remoteAddr +} + +func (fc *FakeConn) SetLocalAddr(addr net.Addr) { + fc.localAddr = addr +} + +func (fc *FakeConn) SetRemoteAddr(addr net.Addr) { + fc.remoteAddr = addr +} + +func (fc *FakeConn) SetDeadline(t time.Time) error { + panic("not implemented") +} + +func (fc *FakeConn) SetReadDeadline(t time.Time) error { + panic("not implemented") +} + +func (fc *FakeConn) SetWriteDeadline(t time.Time) error { + panic("not implemented") +} diff --git a/backend/services/stomp/testutil/fake_conn_test.go b/backend/services/stomp/testutil/fake_conn_test.go new file mode 100644 index 0000000..0ca9273 --- /dev/null +++ b/backend/services/stomp/testutil/fake_conn_test.go @@ -0,0 +1,61 @@ +package testutil + +import ( + . "gopkg.in/check.v1" + "testing" +) + +func TestTestUtil(t *testing.T) { + TestingT(t) +} + +type FakeConnSuite struct{} + +var _ = Suite(&FakeConnSuite{}) + +func (s *FakeConnSuite) TestFakeConn(c *C) { + //c.Skip("temporary") + fc1, fc2 := NewFakeConn(c) + + one := []byte{1, 2, 3, 4} + two := []byte{5, 6, 7, 8, 9, 10, 11, 12, 13} + stop := make(chan struct{}) + + go func() { + defer func() { + fc2.Close() + close(stop) + }() + + rx1 := make([]byte, 6) + n, err := fc2.Read(rx1) + c.Assert(n, Equals, 4) + c.Assert(err, IsNil) + c.Assert(rx1[0:n], DeepEquals, one) + + rx2 := make([]byte, 5) + n, err = fc2.Read(rx2) + c.Assert(n, Equals, 5) + c.Assert(err, IsNil) + c.Assert(rx2, DeepEquals, []byte{5, 6, 7, 8, 9}) + + rx3 := make([]byte, 10) + n, err = fc2.Read(rx3) + c.Assert(n, Equals, 4) + c.Assert(err, IsNil) + c.Assert(rx3[0:n], DeepEquals, []byte{10, 11, 12, 13}) + }() + + c.Assert(fc1.C, Equals, c) + c.Assert(fc2.C, Equals, c) + + n, err := fc1.Write(one) + c.Assert(n, Equals, 4) + c.Assert(err, IsNil) + + n, err = fc1.Write(two) + c.Assert(n, Equals, len(two)) + c.Assert(err, IsNil) + + <-stop +} diff --git a/backend/services/stomp/testutil/mock_logger.go b/backend/services/stomp/testutil/mock_logger.go new file mode 100644 index 0000000..4bd99e0 --- /dev/null +++ b/backend/services/stomp/testutil/mock_logger.go @@ -0,0 +1,150 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./logger.go + +// Package testutil is a generated GoMock package. +package testutil + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debug mocks base method. +func (m *MockLogger) Debug(message string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", message) +} + +// Debug indicates an expected call of Debug. +func (mr *MockLoggerMockRecorder) Debug(message interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), message) +} + +// Debugf mocks base method. +func (m *MockLogger) Debugf(format string, value ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range value { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debugf", varargs...) +} + +// Debugf indicates an expected call of Debugf. +func (mr *MockLoggerMockRecorder) Debugf(format interface{}, value ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, value...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...) +} + +// Error mocks base method. +func (m *MockLogger) Error(message string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Error", message) +} + +// Error indicates an expected call of Error. +func (mr *MockLoggerMockRecorder) Error(message interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), message) +} + +// Errorf mocks base method. +func (m *MockLogger) Errorf(format string, value ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range value { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Errorf", varargs...) +} + +// Errorf indicates an expected call of Errorf. +func (mr *MockLoggerMockRecorder) Errorf(format interface{}, value ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, value...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...) +} + +// Info mocks base method. +func (m *MockLogger) Info(message string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Info", message) +} + +// Info indicates an expected call of Info. +func (mr *MockLoggerMockRecorder) Info(message interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), message) +} + +// Infof mocks base method. +func (m *MockLogger) Infof(format string, value ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range value { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Infof", varargs...) +} + +// Infof indicates an expected call of Infof. +func (mr *MockLoggerMockRecorder) Infof(format interface{}, value ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, value...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Infof", reflect.TypeOf((*MockLogger)(nil).Infof), varargs...) +} + +// Warning mocks base method. +func (m *MockLogger) Warning(message string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Warning", message) +} + +// Warning indicates an expected call of Warning. +func (mr *MockLoggerMockRecorder) Warning(message interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warning", reflect.TypeOf((*MockLogger)(nil).Warning), message) +} + +// Warningf mocks base method. +func (m *MockLogger) Warningf(format string, value ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range value { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warningf", varargs...) +} + +// Warningf indicates an expected call of Warningf. +func (mr *MockLoggerMockRecorder) Warningf(format interface{}, value ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, value...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warningf", reflect.TypeOf((*MockLogger)(nil).Warningf), varargs...) +} diff --git a/backend/services/stomp/testutil/testutil.go b/backend/services/stomp/testutil/testutil.go new file mode 100644 index 0000000..ba625e6 --- /dev/null +++ b/backend/services/stomp/testutil/testutil.go @@ -0,0 +1,5 @@ +/* +Package testutil contains operations useful for testing. In particular, +it provides fake connections useful for testing client/server interactions. +*/ +package testutil diff --git a/backend/services/stomp/transaction.go b/backend/services/stomp/transaction.go new file mode 100644 index 0000000..0a8398f --- /dev/null +++ b/backend/services/stomp/transaction.go @@ -0,0 +1,178 @@ +package stomp + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// A Transaction applies to the sending of messages to the STOMP server, +// and the acknowledgement of messages received from the STOMP server. +// All messages sent and and acknowledged in the context of a transaction +// are processed atomically by the STOMP server. +// +// Transactions are committed with the Commit method. When a transaction is +// committed, all sent messages, acknowledgements and negative acknowledgements, +// are processed by the STOMP server. Alternatively transactions can be aborted, +// in which case all sent messages, acknowledgements and negative +// acknowledgements are discarded by the STOMP server. +type Transaction struct { + id string + conn *Conn + completed bool +} + +// Id returns the unique identifier for the transaction. +func (tx *Transaction) Id() string { + return tx.id +} + +// Conn returns the connection associated with this transaction. +func (tx *Transaction) Conn() *Conn { + return tx.conn +} + +// Abort will abort the transaction. Any calls to Send, SendWithReceipt, +// Ack and Nack on this transaction will be discarded. +// This function does not wait for the server to process the ABORT frame. +// See AbortWithReceipt if you want to ensure the ABORT is processed. +func (tx *Transaction) Abort() error { + return tx.abort(false) +} + +// Abort will abort the transaction. Any calls to Send, SendWithReceipt, +// Ack and Nack on this transaction will be discarded. +func (tx *Transaction) AbortWithReceipt() error { + return tx.abort(true) +} + +func (tx *Transaction) abort(receipt bool) error { + if tx.completed { + return ErrCompletedTransaction + } + + f := frame.New(frame.ABORT, frame.Transaction, tx.id) + + if receipt { + id := allocateId() + f.Header.Set(frame.Receipt, id) + } + + err := tx.conn.sendFrame(f) + if err != nil { + return err + } + tx.completed = true + + return nil +} + +// Commit will commit the transaction. All messages and acknowledgements +// sent to the STOMP server on this transaction will be processed atomically. +// This function does not wait for the server to process the COMMIT frame. +// See CommitWithReceipt if you want to ensure the COMMIT is processed. +func (tx *Transaction) Commit() error { + return tx.commit(false) +} + +// Commit will commit the transaction. All messages and acknowledgements +// sent to the STOMP server on this transaction will be processed atomically. +func (tx *Transaction) CommitWithReceipt() error { + return tx.commit(true) +} + +func (tx *Transaction) commit(receipt bool) error { + if tx.completed { + return ErrCompletedTransaction + } + + f := frame.New(frame.COMMIT, frame.Transaction, tx.id) + + if receipt { + id := allocateId() + f.Header.Set(frame.Receipt, id) + } + + err := tx.conn.sendFrame(f) + if err != nil { + return err + } + tx.completed = true + + return nil +} + +// Send sends a message to the STOMP server as part of a transaction. The server will not process the +// message until the transaction is committed. +// This method returns without confirming that the STOMP server has received the message. If the STOMP server +// does fail to receive the message for any reason, the connection will close. +// +// The content type should be specified, according to the STOMP specification, but if contentType is an empty +// string, the message will be delivered without a content type header entry. The body array contains the +// message body, and its content should be consistent with the specified content type. +// +// TODO: document opts +func (tx *Transaction) Send(destination, contentType string, body []byte, opts ...func(*frame.Frame) error) error { + if tx.completed { + return ErrCompletedTransaction + } + + f, err := createSendFrame(destination, contentType, body, opts) + if err != nil { + return err + } + + f.Header.Set(frame.Transaction, tx.id) + return tx.conn.sendFrame(f) +} + +// Ack sends an acknowledgement for the message to the server. The STOMP +// server will not process the acknowledgement until the transaction +// has been committed. If the subscription has an AckMode of AckAuto, calling +// this function has no effect. +func (tx *Transaction) Ack(msg *Message) error { + if tx.completed { + return ErrCompletedTransaction + } + + f, err := tx.conn.createAckNackFrame(msg, true) + if err != nil { + return err + } + + if f != nil { + f.Header.Set(frame.Transaction, tx.id) + err := tx.conn.sendFrame(f) + if err != nil { + return err + } + } + + return nil +} + +// Nack sends a negative acknowledgement for the message to the server, +// indicating that this client cannot or will not process the message and +// that it should be processed elsewhere. The STOMP server will not process +// the negative acknowledgement until the transaction has been committed. +// It is an error to call this method if the subscription has an AckMode +// of AckAuto, because the STOMP server will not be expecting any kind +// of acknowledgement (positive or negative) for this message. +func (tx *Transaction) Nack(msg *Message) error { + if tx.completed { + return ErrCompletedTransaction + } + + f, err := tx.conn.createAckNackFrame(msg, false) + if err != nil { + return err + } + + if f != nil { + f.Header.Set(frame.Transaction, tx.id) + err := tx.conn.sendFrame(f) + if err != nil { + return err + } + } + + return nil +} diff --git a/backend/services/stomp/validator.go b/backend/services/stomp/validator.go new file mode 100644 index 0000000..8e64a2c --- /dev/null +++ b/backend/services/stomp/validator.go @@ -0,0 +1,21 @@ +package stomp + +import ( + "github.com/go-stomp/stomp/v3/frame" +) + +// Validator is an interface for validating STOMP frames. +type Validator interface { + // Validate returns nil if the frame is valid, or an error if not valid. + Validate(f *frame.Frame) error +} + +func NewValidator(version Version) Validator { + return validatorNull{} +} + +type validatorNull struct{} + +func (v validatorNull) Validate(f *frame.Frame) error { + return nil +} diff --git a/backend/services/stomp/version.go b/backend/services/stomp/version.go new file mode 100644 index 0000000..d14296f --- /dev/null +++ b/backend/services/stomp/version.go @@ -0,0 +1,40 @@ +package stomp + +// Version is the STOMP protocol version. +type Version string + +const ( + V10 Version = "1.0" + V11 Version = "1.1" + V12 Version = "1.2" +) + +// String returns a string representation of the STOMP version. +func (v Version) String() string { + return string(v) +} + +// CheckSupported is used to determine whether a particular STOMP +// version is supported by this library. Returns nil if the version is +// supported, or ErrUnsupportedVersion if not supported. +func (v Version) CheckSupported() error { + switch v { + case V10, V11, V12: + return nil + } + return ErrUnsupportedVersion +} + +// SupportsNack indicates whether this version of the STOMP protocol +// supports use of the NACK command. +func (v Version) SupportsNack() bool { + switch v { + case V10: + return false + case V11, V12: + return true + } + + // unknown version + return false +} diff --git a/backend/services/stomp/version_test.go b/backend/services/stomp/version_test.go new file mode 100644 index 0000000..c9de88f --- /dev/null +++ b/backend/services/stomp/version_test.go @@ -0,0 +1,79 @@ +package stomp_test + +import ( + "testing" + + "github.com/go-stomp/stomp/v3" +) + +func TestSupportsNack(t *testing.T) { + testCases := []struct { + Version stomp.Version + SupportsNack bool + }{ + { + Version: stomp.Version("1.0"), + SupportsNack: false, + }, + { + Version: stomp.Version("1.1"), + SupportsNack: true, + }, + { + Version: stomp.Version("1.2"), + SupportsNack: true, + }, + { + Version: stomp.Version("xxx"), + SupportsNack: false, + }, + } + + for _, testCase := range testCases { + version := testCase.Version + expected := testCase.SupportsNack + actual := version.SupportsNack() + if expected != actual { + t.Errorf("Version %v: SupportsNack: expected %v, actual %v", + version, expected, actual) + } + + } + +} + +func TestCheckSupported(t *testing.T) { + testCases := []struct { + Version stomp.Version + Err error + }{ + { + Version: stomp.Version("1.0"), + Err: nil, + }, + { + Version: stomp.Version("1.1"), + Err: nil, + }, + { + Version: stomp.Version("1.2"), + Err: nil, + }, + { + Version: stomp.Version("2.2"), + Err: stomp.ErrUnsupportedVersion, + }, + } + + for _, testCase := range testCases { + version := testCase.Version + expected := testCase.Err + actual := version.CheckSupported() + if expected != actual { + t.Errorf("Version %v: CheckSupported: expected %v, actual %v", + version, expected, actual) + } + + } + +}