refactor: mqtt broker mochi import packages + connack packet custom

This commit is contained in:
Leandro Antônio Farias Machado 2023-07-26 11:26:13 -03:00
parent 1a1d6abcc1
commit 51acf7569f
1003 changed files with 29 additions and 378906 deletions

View File

@ -1,43 +0,0 @@
name: build
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.19
- name: Vet
run: go vet ./...
- name: Test
run: go test -race ./... && echo true
coverage:
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.19.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov

View File

@ -1,103 +0,0 @@
linters:
disable-all: false
fix: false # Fix found issues (if it's supported by the linter).
enable:
# - asasalint
# - asciicheck
# - bidichk
# - bodyclose
# - containedctx
# - contextcheck
#- cyclop
# - deadcode
- decorder
# - depguard
# - dogsled
# - dupl
- durationcheck
# - errchkjson
# - errname
- errorlint
# - execinquery
# - exhaustive
# - exhaustruct
# - exportloopref
#- forcetypeassert
#- forbidigo
#- funlen
#- gci
# - gochecknoglobals
# - gochecknoinits
# - gocognit
# - goconst
# - gocritic
- gocyclo
- godot
# - godox
# - goerr113
# - gofmt
# - gofumpt
# - goheader
- goimports
# - golint
# - gomnd
# - gomoddirectives
# - gomodguard
# - goprintffuncname
- gosec
- gosimple
- govet
# - grouper
# - ifshort
- importas
- ineffassign
# - interfacebloat
# - interfacer
# - ireturn
# - lll
# - maintidx
# - makezero
- maligned
- misspell
# - nakedret
# - nestif
# - nilerr
# - nilnil
# - nlreturn
# - noctx
# - nolintlint
# - nonamedreturns
# - nosnakecase
# - nosprintfhostport
# - paralleltest
# - prealloc
# - predeclared
# - promlinter
- reassign
# - revive
# - rowserrcheck
# - scopelint
# - sqlclosecheck
# - staticcheck
# - structcheck
# - stylecheck
# - tagliatelle
# - tenv
# - testpackage
# - thelper
- tparallel
# - typecheck
- unconvert
- unparam
- unused
- usestdlibvars
# - varcheck
# - varnamelen
- wastedassign
- whitespace
# - wrapcheck
# - wsl
disable:
- errcheck

View File

@ -1,31 +0,0 @@
FROM golang:1.19.0-alpine3.15 AS builder
RUN apk update
RUN apk add git
WORKDIR /app
COPY go.mod ./
COPY go.sum ./
RUN go mod download
COPY . ./
RUN go build -o /app/mochi ./cmd
FROM alpine
WORKDIR /
COPY --from=builder /app/mochi .
# tcp
EXPOSE 1883
# websockets
EXPOSE 1882
# dashboard
EXPOSE 8080
ENTRYPOINT [ "/mochi" ]

View File

@ -1,22 +0,0 @@
The MIT License (MIT)
Copyright (c) 2019, 2022 Jonathan Blake (mochi-co)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,5 +0,0 @@
This broker is an implementation of mochi. I've to forked it to customize CONNACK packet userProperty, although mochi lib might have a better approach to do it.
To run this project you might have Go compiler in your machine, and inside cmd folder there is a run.sh script, which runs the project with the right arguments; also inside the same folder is the auth.json file, that carries configs of RBAC.
![img.png](img.png)

View File

@ -1,568 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
"github.com/mochi-co/mqtt/v2/packets"
)
const (
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
)
// ReadFn is the function signature for the function used for reading and processing new packets.
type ReadFn func(*Client, packets.Packet) error
// Clients contains a map of the clients known by the broker.
type Clients struct {
internal map[string]*Client // clients known by the broker, keyed on client id.
sync.RWMutex
}
// NewClients returns an instance of Clients.
func NewClients() *Clients {
return &Clients{
internal: make(map[string]*Client),
}
}
// Add adds a new client to the clients map, keyed on client id.
func (cl *Clients) Add(val *Client) {
cl.Lock()
defer cl.Unlock()
cl.internal[val.ID] = val
}
// GetAll returns all the clients.
func (cl *Clients) GetAll() map[string]*Client {
cl.RLock()
defer cl.RUnlock()
m := map[string]*Client{}
for k, v := range cl.internal {
m[k] = v
}
return m
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
defer cl.RUnlock()
val, ok := cl.internal[id]
return val, ok
}
// Len returns the length of the clients map.
func (cl *Clients) Len() int {
cl.RLock()
defer cl.RUnlock()
val := len(cl.internal)
return val
}
// Delete removes a client from the internal map.
func (cl *Clients) Delete(id string) {
cl.Lock()
defer cl.Unlock()
delete(cl.internal, id)
}
// GetByListener returns clients matching a listener id.
func (cl *Clients) GetByListener(id string) []*Client {
cl.RLock()
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
clients = append(clients, client)
}
}
return clients
}
// Client contains information about a client known by the broker.
type Client struct {
Properties ClientProperties // client properties
State ClientState // the operational state of the client.
Net ClientConnection // network connection state of the clinet
ID string // the client id.
ops *ops // ops provides a reference to server ops.
sync.RWMutex // mutex
}
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
Inline bool // client is an inline programmetic client
}
// ClientProperties contains the properties which define the client behaviour.
type ClientProperties struct {
Props packets.Properties
Will Will
Username []byte
ProtocolVersion byte
Clean bool
}
// Will contains the last will and testament details for a client connection.
type Will struct {
Payload []byte // -
User []packets.UserProperty // -
TopicName string // -
Flag uint32 // 0,1
WillDelayInterval uint32 // -
Qos byte // -
Retain bool // -
}
// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
}
// newClient returns a new instance of Client. This is almost exclusively used by Server
// for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
cl := &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
ops: o,
}
if c != nil {
cl.Net = ClientConnection{
Conn: c,
bconn: bufio.NewReadWriter(
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize),
),
Remote: c.RemoteAddr().String(),
}
}
cl.refreshDeadline(cl.State.keepalive)
return cl
}
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for pk := range cl.State.outbound {
if err := cl.WritePacket(*pk); err != nil {
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
}
atomic.AddInt32(&cl.State.outboundQty, -1)
}
}
// ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid
cl.Properties.ProtocolVersion = pk.ProtocolVersion
cl.Properties.Username = pk.Connect.Username
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
cl.ID = pk.Connect.ClientIdentifier
if cl.ID == "" {
cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
cl.Properties.Props.AssignedClientID = cl.ID
}
cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}
if pk.Connect.WillFlag {
cl.Properties.Will = Will{
Qos: pk.Connect.WillQos,
Retain: pk.Connect.WillRetain,
Payload: pk.Connect.WillPayload,
TopicName: pk.Connect.WillTopic,
WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
User: pk.Connect.WillProperties.User,
}
if pk.Properties.SessionExpiryIntervalFlag &&
pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
}
if pk.Connect.WillFlag {
cl.Properties.Will.Flag = 1 // atomic for checking
}
}
cl.refreshDeadline(cl.State.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
if cl.Net.Conn != nil {
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
// NextPacketID returns the next available (unused) packet id for the client.
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i
overflowed := false
for {
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
if i >= cl.ops.options.Capabilities.maximumPacketID {
overflowed = true
i = 0
continue
}
i++
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
}
}
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
func (cl *Client) ResendInflightMessages(force bool) error {
if cl.State.Inflight.Len() == 0 {
return nil
}
for _, tk := range cl.State.Inflight.GetAll(false) {
if tk.FixedHeader.Type == packets.Publish {
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
}
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
err := cl.WritePacket(tk)
if err != nil {
return err
}
if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosComplete(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
}
}
}
return nil
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted = append(deleted, tk.PacketID)
}
}
}
return deleted
}
// Read reads incoming packets from the connected client and transforms them into
// packets to be handled by the packetHandler.
func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
return nil
}
cl.refreshDeadline(cl.State.keepalive)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
return err
}
pk, err := cl.ReadPacket(fh)
if err != nil {
return err
}
err = packetHandler(cl, pk) // Process inbound packet.
if err != nil {
return err
}
}
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.done) == 1 {
return
}
cl.State.endOnce.Do(func() {
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
if err != nil {
cl.State.stopCause.Store(err)
}
atomic.StoreUint32(&cl.State.done, 1)
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
// StopCause returns the reason the client connection was stopped, if any.
func (cl *Client) StopCause() error {
if cl.State.stopCause.Load() == nil {
return nil
}
return cl.State.stopCause.Load().(error)
}
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return atomic.LoadUint32(&cl.State.done) == 1
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
return ErrConnectionClosed
}
b, err := cl.Net.bconn.ReadByte()
if err != nil {
return err
}
err = fh.Decode(b)
if err != nil {
return err
}
var bu int
fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn)
if err != nil {
return err
}
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil
}
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
pk.FixedHeader = *fh
p := make([]byte, pk.FixedHeader.Remaining)
n, err := io.ReadFull(cl.Net.bconn, p)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
px := append([]byte{}, p[:]...)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectDecode(px)
case packets.Disconnect:
err = pk.DisconnectDecode(px)
case packets.Connack:
err = pk.ConnackDecode(px)
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
case packets.Pubrec:
err = pk.PubrecDecode(px)
case packets.Pubrel:
err = pk.PubrelDecode(px)
case packets.Pubcomp:
err = pk.PubcompDecode(px)
case packets.Subscribe:
err = pk.SubscribeDecode(px)
case packets.Suback:
err = pk.SubackDecode(px)
case packets.Unsubscribe:
err = pk.UnsubscribeDecode(px)
case packets.Unsuback:
err = pk.UnsubackDecode(px)
case packets.Pingreq:
case packets.Pingresp:
case packets.Auth:
err = pk.AuthDecode(px)
default:
err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
}
if err != nil {
return pk, err
}
pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) error {
if atomic.LoadUint32(&cl.State.done) == 1 {
return ErrConnectionClosed
}
if cl.Net.Conn == nil {
return nil
}
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
pk.ProtocolVersion = cl.Properties.ProtocolVersion
if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
}
if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
var err error
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectEncode(buf)
case packets.Connack:
err = pk.ConnackEncode(buf)
case packets.Publish:
err = pk.PublishEncode(buf)
case packets.Puback:
err = pk.PubackEncode(buf)
case packets.Pubrec:
err = pk.PubrecEncode(buf)
case packets.Pubrel:
err = pk.PubrelEncode(buf)
case packets.Pubcomp:
err = pk.PubcompEncode(buf)
case packets.Subscribe:
err = pk.SubscribeEncode(buf)
case packets.Suback:
err = pk.SubackEncode(buf)
case packets.Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case packets.Unsuback:
err = pk.UnsubackEncode(buf)
case packets.Pingreq:
err = pk.PingreqEncode(buf)
case packets.Pingresp:
err = pk.PingrespEncode(buf)
case packets.Disconnect:
err = pk.DisconnectEncode(buf)
case packets.Auth:
err = pk.AuthEncode(buf)
default:
err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
}
if err != nil {
return err
}
if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.Conn)
if err != nil {
return err
}
atomic.AddInt64(&cl.ops.info.BytesSent, n)
atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
if pk.FixedHeader.Type == packets.Publish {
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
}
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
return err
}

View File

@ -1,745 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var errClientStop = errors.New("test stop")
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
options: &Options{
Capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
},
})
cl.ID = "mochi"
cl.State.Inflight.maximumSendQuota = 5
cl.State.Inflight.sendQuota = 5
cl.State.Inflight.maximumReceiveQuota = 10
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
go cl.WriteLoop()
return
}
func TestNewInflights(t *testing.T) {
require.NotNil(t, NewInflights().internal)
}
func TestNewClients(t *testing.T) {
cl := NewClients()
require.NotNil(t, cl.internal)
}
func TestClientsAdd(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
require.Contains(t, cl.internal, "t1")
}
func TestClientsGet(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
client, ok := cl.Get("t1")
require.Equal(t, true, ok)
require.Equal(t, "t1", client.ID)
}
func TestClientsGetAll(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
cl.Add(&Client{ID: "t3"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Contains(t, cl.internal, "t3")
clients := cl.GetAll()
require.Len(t, clients, 3)
}
func TestClientsLen(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Equal(t, 2, cl.Len())
}
func TestClientsDelete(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
require.Contains(t, cl.internal, "t1")
cl.Delete("t1")
_, ok := cl.Get("t1")
require.Equal(t, false, ok)
require.Nil(t, cl.internal["t1"])
}
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
clients := cl.GetByListener("tcp1")
require.NotEmpty(t, clients)
require.Equal(t, 1, len(clients))
require.Equal(t, "tcp1", clients[0].Net.Listener)
}
func TestNewClient(t *testing.T) {
cl, _, _ := newTestClient()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.NotNil(t, cl.ops.options.Capabilities)
require.False(t, cl.Net.Inline)
}
func TestClientParseConnect(t *testing.T) {
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillTopic: "lwt",
WillPayload: []byte("lol gg"),
WillQos: 1,
WillRetain: false,
},
Properties: packets.Properties{
ReceiveMaximum: uint16(5),
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
}
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillProperties: packets.Properties{
WillDelayInterval: 200,
},
},
Properties: packets.Properties{
SessionExpiryInterval: 100,
SessionExpiryIntervalFlag: true,
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval)
}
func TestClientParseConnectNoID(t *testing.T) {
cl, _, _ := newTestClient()
cl.ParseConnect("tcp1", packets.Packet{})
require.NotEmpty(t, cl.ID)
}
func TestClientNextPacketID(t *testing.T) {
cl, _, _ := newTestClient()
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(2), i)
}
func TestClientNextPacketIDInUse(t *testing.T) {
cl, _, _ := newTestClient()
// skip over 2
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(3), i)
// Skip over overflow
cl.State.Inflight.Set(packets.Packet{PacketID: 65535})
atomic.StoreUint32(&cl.State.packetID, 65534)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
}
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
}
i, err := cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
require.Equal(t, uint32(0), i)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
_, err = cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
}
func TestClientClearInflights(t *testing.T) {
cl, _, _ := newTestClient()
n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
deleted := cl.ClearInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
require.Equal(t, 2, cl.State.Inflight.Len())
}
func TestClientResendInflightMessages(t *testing.T) {
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
cl, r, w := newTestClient()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
go func() {
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, 0, cl.State.Inflight.Len())
require.Equal(t, pk1.RawBytes, buf)
}
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newTestClient()
r.Close()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
err := cl.ResendInflightMessages(true)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrClosedPipe)
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
cl, _, _ := newTestClient()
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
}
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newTestClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0x00})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived))
}
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
}
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.options.Capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.Equal(t, io.EOF, err)
}
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
}
func TestClientReadOK(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload,
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
}()
var pks []packets.Packet
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
pks = append(pks, pk)
return nil
})
}()
err := <-o
require.Error(t, err)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 2, len(pks))
require.Equal(t, []packets.Packet{
{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 18,
},
TopicName: "a/b/c",
Payload: []byte("hello mochi"),
},
{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 11,
},
TopicName: "d/e/f",
Payload: []byte("yeah"),
},
}, pks)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived))
}
func TestClientReadDone(t *testing.T) {
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
return nil
})
}()
require.NoError(t, <-o)
}
func TestClientStop(t *testing.T) {
cl, _, _ := newTestClient()
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.Equal(t, nil, cl.StopCause())
}
func TestClientClosed(t *testing.T) {
cl, _, _ := newTestClient()
require.False(t, cl.Closed())
cl.Stop(nil)
require.True(t, cl.Closed())
}
func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
})
r.Close()
}()
cl.Net.bconn = nil
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, ErrConnectionClosed, err)
}
func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
}()
err := cl.Read(func(cl *Client, pk packets.Packet) error {
return errors.New("test")
})
require.Error(t, err)
}
func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
pk, err := cl.ReadPacket(fh)
require.NoError(t, err)
require.NotNil(t, pk)
require.Equal(t, packets.Packet{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 11,
},
TopicName: "d/e/f",
Payload: []byte("yeah"),
}, pk)
}
func TestClientReadPacket(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
for _, tx := range pkTable {
tt := tx // avoid data race
t.Run(tt.Desc, func(t *testing.T) {
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
go func() {
r.Write(tt.RawBytes)
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
if tt.Packet.ProtocolVersion == 5 {
cl.Properties.ProtocolVersion = 5
} else {
cl.Properties.ProtocolVersion = 0
}
pk, err := cl.ReadPacket(fh)
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc)
require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc)
if tt.Packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc)
}
})
}
}
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
}
func TestClientWritePacket(t *testing.T) {
for _, tt := range pkTable {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
o := make(chan []byte)
go func() {
buf, err := io.ReadAll(r)
require.NoError(t, err)
o <- buf
}()
err := cl.WritePacket(*tt.Packet)
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
cl.Stop(errClientStop)
time.Sleep(time.Millisecond * 1)
// The stop cause is either the test error, EOF, or a
// closed pipe, depending on which goroutine runs first.
err = cl.StopCause()
require.True(t,
errors.Is(err, errClientStop) ||
errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe))
require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent))
if tt.Packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent))
}
}
}
func TestWriteClientOversizePacket(t *testing.T) {
cl, _, _ := newTestClient()
cl.Properties.Props.MaximumPacketSize = 2
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
err := cl.WritePacket(pk)
require.Error(t, err)
require.ErrorIs(t, packets.ErrPacketTooLarge, err)
}
func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
Type: 0,
Remaining: 11,
})
require.Error(t, err)
}
func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
Remaining: 1,
})
require.Error(t, err)
}
func TestClientWritePacketWriteNoConn(t *testing.T) {
cl, _, _ := newTestClient()
cl.Stop(errClientStop)
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
require.Equal(t, ErrConnectionClosed, err)
}
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
}
func TestClientWritePacketInvalidPacket(t *testing.T) {
cl, _, _ := newTestClient()
err := cl.WritePacket(packets.Packet{})
require.Error(t, err)
}
var (
pkTable = []packets.TPacketCase{
packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311),
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5),
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession),
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic),
packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5),
packets.TPacketData[packets.Puback].Get(packets.TPuback),
packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe),
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5),
packets.TPacketData[packets.Suback].Get(packets.TSuback),
packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5),
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe),
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5),
packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback),
packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5),
packets.TPacketData[packets.Pingreq].Get(packets.TPingreq),
packets.TPacketData[packets.Pingresp].Get(packets.TPingresp),
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect),
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5),
packets.TPacketData[packets.Auth].Get(packets.TAuth),
}
)

View File

@ -5,6 +5,7 @@ import (
"crypto/tls"
"flag"
rv8 "github.com/go-redis/redis/v8"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage/redis"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/rs/zerolog"
@ -15,42 +16,12 @@ import (
"strings"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
var (
// testCertificate = []byte(`-----BEGIN CERTIFICATE-----
//MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
//VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
//BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
//VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
//DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
//AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
//OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
//MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
//gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
//qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
//zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
//-----END CERTIFICATE-----`)
//
// testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
//MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
//FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
//rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
//AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
//UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
//n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
//mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
//INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
//AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
///F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
//WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
//w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
//OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
//-----END RSA PRIVATE KEY-----`)
//TODO: create custom mqtt server options
server = mqtt.New(&mqtt.Options{
//Capabilities: &mqtt.Capabilities{
// ServerKeepAlive: 10000,
@ -233,6 +204,7 @@ func (h *MyHook) Provides(b byte) bool {
mqtt.OnSubscribed,
mqtt.OnDisconnect,
mqtt.OnClientExpired,
mqtt.OnPacketEncode,
}, []byte{b})
}
@ -284,3 +256,15 @@ func (h *MyHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []
}
}
func (h *MyHook) OnPacketEncode(cl *mqtt.Client, pk packets.Packet) packets.Packet {
var clUser string
if len(cl.Properties.Props.User) > 0 {
clUser = cl.Properties.Props.User[0].Val
}
if pk.FixedHeader.Type == packets.Connack {
pk.Properties.User = []packets.UserProperty{{Key: "subscribe-topic", Val: "oktopus/v1/agent/" + clUser}}
}
return pk
}

View File

@ -1,83 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
authRules := &auth.Ledger{
Auth: auth.AuthRules{ // Auth disallows all by default
{Username: "peach", Password: "password1", Allow: true},
{Username: "melon", Password: "password2", Allow: true},
{Remote: "127.0.0.1:*", Allow: true},
{Remote: "localhost:*", Allow: true},
},
ACL: auth.ACLRules{ // ACL allows all by default
{Remote: "127.0.0.1:*"}, // local superuser allow all
{
// user melon can read and write to their own topic
Username: "melon", Filters: auth.Filters{
"melon/#": auth.ReadWrite,
"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
},
},
{
// Otherwise, no clients have publishing permissions
Filters: auth.Filters{
"#": auth.ReadOnly,
"updates/#": auth.Deny,
},
},
},
}
// you may also find this useful...
// d, _ := authRules.ToYAML()
// d, _ := authRules.ToJSON()
// fmt.Println(string(d))
server := mqtt.New(nil)
err := server.AddHook(new(auth.Hook), &auth.Options{
Ledger: authRules,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,52 +0,0 @@
{
"auth": [
{
"username": "leandro",
"password": "leandro",
"allow": true
},
{
"username": "steve",
"password": "steve",
"allow": true
},
{
"username": "root",
"password": "root",
"allow": true
},
{
"remote": "*",
"allow": false
},
{
"remote": "*",
"allow": false
}
],
"acl": [
{
"remote": "*"
},
{
"username": "leandro",
"filters": {
"oktopus/+/agent/+": 1,
"oktopus/+/controller/+": 2
}
},
{
"username": "steve",
"filters": {
"oktopus/+/agent/+": 1,
"oktopus/+/controller/+": 2
}
},
{
"username": "root",
"filters": {
"#": 3
}
}
]
}

View File

@ -1,21 +0,0 @@
auth:
- username: peach
password: password1
allow: true
- username: melon
password: password2
allow: true
# - remote: 127.0.0.1:*
# allow: true
# - remote: localhost:*
# allow: true
acl:
# 0 = deny, 1 = read only, 2 = write only, 3 = read and write
- remote: 127.0.0.1:*
- username: melon
filters:
melon/#: 3
updates/#: 2
- filters:
'#': 1
updates/#: 0

View File

@ -1,65 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
// You can also run from top-level server.go folder:
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.yaml
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.json
path := flag.String("path", "auth.yaml", "path to data auth file")
flag.Parse()
// Get ledger from yaml file
data, err := os.ReadFile(*path)
if err != nil {
log.Fatal(err)
}
server := mqtt.New(nil)
err = server.AddHook(new(auth.Hook), &auth.Options{
Data: data, // build ledger from byte slice, yaml or json
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,52 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,62 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/debug"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/rs/zerolog"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(debug.Hook), &debug.Options{
// ShowPacketData: true,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,143 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"bytes"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/mochi-co/mqtt/v2/packets"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(ExampleHook), map[string]any{})
if err != nil {
log.Fatal(err)
}
// Start the server
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
// Demonstration of directly publishing messages to a topic via the
// `server.Publish` method. Subscribe to `direct/publish` using your
// MQTT client to see the messages.
go func() {
cl := server.NewClient(nil, "local", "inline", true)
for range time.Tick(time.Second * 1) {
err := server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "direct/publish",
Payload: []byte("injected scheduled message"),
})
if err != nil {
server.Log.Error().Err(err).Msg("server.InjectPacket")
}
server.Log.Info().Msgf("main.go injected packet to direct/publish")
}
}()
// There is also a shorthand convenience function, Publish, for easily sending
// publish packets if you are not concerned with creating your own packets.
go func() {
for range time.Tick(time.Second * 5) {
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
if err != nil {
server.Log.Error().Err(err).Msg("server.Publish")
}
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}
type ExampleHook struct {
mqtt.HookBase
}
func (h *ExampleHook) ID() string {
return "events-example"
}
func (h *ExampleHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnect,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnPublished,
mqtt.OnPublish,
}, []byte{b})
}
func (h *ExampleHook) Init(config any) error {
h.Log.Info().Msg("initialised")
return nil
}
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
}
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected")
}
func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes)
}
func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed")
}
func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client")
pkx := pk
if string(pk.Payload) == "hello" {
pkx.Payload = []byte("hello world")
h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client")
}
return pkx, nil
}
func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client")
}

View File

@ -1,74 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"bytes"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/mochi-co/mqtt/v2/packets"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
_ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}
type pahoAuthHook struct {
mqtt.HookBase
}
func (h *pahoAuthHook) ID() string {
return "allow-all-auth"
}
func (h *pahoAuthHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
return true
}
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return topic != "test/nosubscribe"
}

View File

@ -1,59 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/badger"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
badgerPath := ".badger"
defer os.RemoveAll(badgerPath) // remove the example badger files at the end
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(badger.Hook), &badger.Options{
Path: badgerPath,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,60 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/bolt"
"github.com/mochi-co/mqtt/v2/listeners"
"go.etcd.io/bbolt"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(bolt.Hook), &bolt.Options{
Path: "bolt.db",
Options: &bbolt.Options{
Timeout: 500 * time.Millisecond,
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,65 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/redis"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/rs/zerolog"
rv8 "github.com/go-redis/redis/v8"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(redis.Hook), &redis.Options{
Options: &rv8.Options{
Addr: "localhost:6379", // default redis address
Password: "", // your password
DB: 0, // your redis db
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,58 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
// An example of configuring various server options...
options := &mqtt.Options{
// InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
}
server := mqtt.New(options)
// For security reasons, the default implementation disallows all connections.
// If you want to allow all connections, you must specifically allow it.
err := server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,117 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"crypto/tls"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
var (
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
// Optionally, if you want clients to authenticate only with certs issued by your CA,
// you might want to use something like this:
// certPool := x509.NewCertPool()
// _ = certPool.AppendCertsFromPEM(caCertPem)
// tlsConfig := &tls.Config{
// ClientCAs: certPool,
// ClientAuth: tls.RequireAndVerifyClientCert,
// }
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", &listeners.Config{
TLSConfig: tlsConfig,
})
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
ws := listeners.NewWebsocket("ws1", ":1882", &listeners.Config{
TLSConfig: tlsConfig,
})
err = server.AddListener(ws)
if err != nil {
log.Fatal(err)
}
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{
TLSConfig: tlsConfig,
}, nil)
err = server.AddListener(stats)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,47 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
ws := listeners.NewWebsocket("ws1", ":1882", nil)
err := server.AddListener(ws)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -1,40 +1,20 @@
module github.com/mochi-co/mqtt/v2
module broker
go 1.19
go 1.18
require (
github.com/alicebob/miniredis/v2 v2.23.0
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/go-redis/redis/v8 v8.11.5
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
github.com/rs/xid v1.4.0
github.com/rs/zerolog v1.28.0
github.com/stretchr/testify v1.7.1
github.com/timshannon/badgerhold v1.0.0
go.etcd.io/bbolt v1.3.5
gopkg.in/yaml.v3 v3.0.1
github.com/mochi-co/mqtt/v2 v2.2.16
github.com/rs/zerolog v1.29.1
)
require (
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/badger v1.6.0 // indirect
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/golang/protobuf v1.5.0 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
golang.org/x/net v0.7.0 // indirect
github.com/rs/xid v1.4.0 // indirect
golang.org/x/sys v0.5.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,143 +1,44 @@
github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M=
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM=
github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw=
github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/asdine/storm v2.1.2+incompatible h1:dczuIkyqwY2LrtXPz8ixMrU/OFgZp71kbKTHGrXYt/Q=
github.com/asdine/storm v2.1.2+incompatible/go.mod h1:RarYDc9hq1UPLImuiXK3BIWPJLdIygvV3PsInK0FbVQ=
github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac=
github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/badger v1.6.0 h1:DshxFxZWXUcO0xX476VJC07Xsr6ZCBVRHKZ93Oh7Evo=
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mochi-co/mqtt/v2 v2.2.16 h1:CBqbxFhExzASNjj4BjSei0hYY1F5N5IeDqNVhjN+tp8=
github.com/mochi-co/mqtt/v2 v2.2.16/go.mod h1:MDMTThFgWj/LjJ6wc51bP5l4xnJG/ahpc9tR9vZVf8Q=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/timshannon/badgerhold v1.0.0 h1:LtqnDRVP7294FWRiZCIfQa6Tt0bGmlzbO8c364QC2Y8=
github.com/timshannon/badgerhold v1.0.0/go.mod h1:Vv2Jj0PAfzqViEpGvJzLP8PY07x1iXLgKRuLY7bqPOE=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,794 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, thedevop
package mqtt
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
)
const (
SetOptions byte = iota
OnSysInfoTick
OnStarted
OnStopped
OnConnectAuthenticate
OnACLCheck
OnConnect
OnSessionEstablished
OnDisconnect
OnAuthPacket
OnPacketRead
OnPacketEncode
OnPacketSent
OnPacketProcessed
OnSubscribe
OnSubscribed
OnSelectSubscribers
OnUnsubscribe
OnUnsubscribed
OnPublish
OnPublished
OnPublishDropped
OnRetainMessage
OnQosPublish
OnQosComplete
OnQosDropped
OnWill
OnWillSent
OnClientExpired
OnRetainedExpired
StoredClients
StoredSubscriptions
StoredInflightMessages
StoredRetainedMessages
StoredSysInfo
)
var (
// ErrInvalidConfigType indicates a different Type of config value was expected to what was received.
ErrInvalidConfigType = errors.New("invalid config type provided")
)
// Hook provides an interface of handlers for different events which occur
// during the lifecycle of the broker.
type Hook interface {
ID() string
Provides(b byte) bool
Init(config any) error
Stop() error
SetOpts(l *zerolog.Logger, o *HookOptions)
OnStarted()
OnStopped()
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
OnACLCheck(cl *Client, topic string, write bool) bool
OnSysInfoTick(*system.Info)
OnConnect(cl *Client, pk packets.Packet)
OnSessionEstablished(cl *Client, pk packets.Packet)
OnDisconnect(cl *Client, err error, expire bool)
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation
OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client
OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client
OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled)
OnSubscribe(cl *Client, pk packets.Packet) packets.Packet
OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers
OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnPublishDropped(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
OnQosDropped(cl *Client, pk packets.Packet)
OnWill(cl *Client, will Will) (Will, error)
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
StoredRetainedMessages() ([]storage.Message, error)
StoredSysInfo() (storage.SystemInfo, error)
}
// HookOptions contains values which are inherited from the server on initialisation.
type HookOptions struct {
Capabilities *Capabilities
}
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
func (h *Hooks) Len() int64 {
return atomic.LoadInt64(&h.qty)
}
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.GetAll() {
for _, hb := range b {
if hook.Provides(hb) {
return true
}
}
}
return false
}
// Add adds and initializes a new hook.
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.GetAll() {
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
if err := hook.Stop(); err != nil {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
}
h.wg.Done()
}
}()
h.wg.Wait()
}
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
}
}
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.GetAll() {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
}
}
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.GetAll() {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
}
}
// OnConnect is called when a new client connects.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnect) {
hook.OnConnect(cl, pk)
}
}
}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
}
}
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.GetAll() {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
}
}
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected")
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
return
}
// OnAuthPacket is called when an auth packet is received. It is intended to allow developers
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
return pk, err
}
pkx = npk
}
}
return
}
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
}
return pk
}
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
}
}
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
}
}
// OnSubscribe is called when a client subscribes to one or more filters. This method
// differs from OnSubscribed in that it allows you to modify the subscription values
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
}
return pk
}
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
}
}
// OnSelectSubscribers is called when subscribers have been collected for a topic, but before
// shared subscription subscribers have been selected. This hook can be used to programmatically
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.GetAll() {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
}
return subs
}
// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method
// differs from OnUnsubscribed in that it allows you to modify the unsubscription values
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
}
return pk
}
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
}
}
// OnPublish is called when a client publishes a message. This method differs from OnPublished
// in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
return
}
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
}
}
// OnPublishDropped is called when a message to a client was dropped instead of delivered
// such as when a client is too slow to respond.
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublishDropped) {
hook.OnPublishDropped(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
}
}
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
}
}
// OnQosComplete is called when the Qos flow for a message has been completed.
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
}
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
}
}
// OnWill is called when a client disconnects and publishes an LWT message. This method
// differs from OnWillSent in that it allows you to modify the LWT message before it is
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.GetAll() {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error")
continue
}
will = mlwt
}
}
return will
}
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
}
}
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.GetAll() {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
}
}
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
}
}
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info")
return v, err
}
if v.Version != "" {
return v, nil
}
}
}
return
}
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
// An implementation of this method MUST be used to allow or deny access to the
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
}
}
}
return false
}
// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter.
// An implementation of this method MUST be used to allow or deny access to the
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.GetAll() {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
}
}
}
return false
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
Hook
Log *zerolog.Logger
Opts *HookOptions
}
// ID returns the ID of the hook.
func (h *HookBase) ID() string {
return "base"
}
// Provides indicates which methods a hook provides. The default is none - this method
// should be overridden by the embedding hook.
func (h *HookBase) Provides(b byte) bool {
return false
}
// Init performs any pre-start initializations for the hook, such as connecting to databases
// or opening files.
func (h *HookBase) Init(config any) error {
return nil
}
// SetOpts is called by the server to propagate internal values and generally should
// not be called manually.
func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
h.Log = l
h.Opts = opts
}
// Stop is called to gracefully shutdown the hook.
func (h *HookBase) Stop() error {
return nil
}
// OnStarted is called when the server starts.
func (h *HookBase) OnStarted() {}
// OnStopped is called when the server stops.
func (h *HookBase) OnStopped() {}
// OnSysInfoTick is called when the server publishes system info.
func (h *HookBase) OnSysInfoTick(*system.Info) {}
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return false
}
// OnACLCheck is called when a user attempts to subscribe or publish to a topic.
func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnConnect is called when a new client connects.
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
// OnDisconnect is called when a client is disconnected for any reason.
func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {}
// OnAuthPacket is called when an auth packet is received from the client.
func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPacketRead is called when a packet is received.
func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPacketEncode is called before a packet is byte-encoded and written to the client.
func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnPacketSent is called immediately after a packet is written to a client.
func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {}
// OnPacketProcessed is called immediately after a packet from a client is processed.
func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {}
// OnSubscribe is called when a client subscribes to one or more filters.
func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {}
// OnSelectSubscribers is called when selecting subscribers to receive a message.
func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
return subs
}
// OnUnsubscribe is called when a client unsubscribes from one or more filters.
func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {}
// OnPublish is called when a client publishes a message.
func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
// OnQosDropped is called the Qos flow for a message expires.
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
// OnWill is called when a client disconnects and publishes an LWT message.
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
return will, nil
}
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {}
// OnClientExpired is called when a client session has expired.
func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return
}
// StoredSubscriptions returns all subcriptions from a store.
func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
return
}
// StoredInflightMessages returns all inflight messages from a store.
func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) {
return
}
// StoredRetainedMessages returns all retained messages from a store.
func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
return
}
// StoredSysInfo returns a set of system info values.
func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
return
}

View File

@ -1,41 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
)
// AllowHook is an authentication hook which allows connection access
// for all users and read and write access to all topics.
type AllowHook struct {
mqtt.HookBase
}
// ID returns the ID of the hook.
func (h *AllowHook) ID() string {
return "allow-all-auth"
}
// Provides indicates which hook methods this hook provides.
func (h *AllowHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
// OnConnectAuthenticate returns true/allowed for all requests.
func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
return true
}
// OnACLCheck returns true/allowed for all checks.
func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return true
}

View File

@ -1,35 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
func TestAllowAllID(t *testing.T) {
h := new(AllowHook)
require.Equal(t, "allow-all-auth", h.ID())
}
func TestAllowAllProvides(t *testing.T) {
h := new(AllowHook)
require.True(t, h.Provides(mqtt.OnACLCheck))
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
require.False(t, h.Provides(mqtt.OnPublished))
}
func TestAllowAllOnConnectAuthenticate(t *testing.T) {
h := new(AllowHook)
require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{}))
}
func TestAllowAllOnACLCheck(t *testing.T) {
h := new(AllowHook)
require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true))
}

View File

@ -1,107 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
)
// Options contains the configuration/rules data for the auth ledger.
type Options struct {
Data []byte
Ledger *Ledger
}
// Hook is an authentication hook which implements an auth ledger.
type Hook struct {
mqtt.HookBase
config *Options
ledger *Ledger
}
// ID returns the ID of the hook.
func (h *Hook) ID() string {
return "auth-ledger"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
// Init configures the hook with the auth ledger to be used for checking.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
var err error
if h.config.Ledger != nil {
h.ledger = h.config.Ledger
} else if len(h.config.Data) > 0 {
h.ledger = new(Ledger)
err = h.ledger.Unmarshal(h.config.Data)
}
if err != nil {
return err
}
if h.ledger == nil {
h.ledger = &Ledger{
Auth: AuthRules{},
ACL: ACLRules{},
}
}
h.Log.Info().
Int("authentication", len(h.ledger.Auth)).
Int("acl", len(h.ledger.ACL)).
Msg("loaded auth rules")
return nil
}
// OnConnectAuthenticate returns true if the connecting client has rules which provide access
// in the auth ledger.
func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
if _, ok := h.ledger.AuthOk(cl, pk); ok {
return true
}
h.Log.Info().
Str("username", string(pk.Connect.Username)).
Str("remote", cl.Net.Remote).
Msg("client failed authentication check")
return false
}
// OnACLCheck returns true if the connecting client has matching read or write access to subscribe
// or publish to a given topic.
func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
if _, ok := h.ledger.ACLOk(cl, topic, write); ok {
return true
}
h.Log.Debug().
Str("client", cl.ID).
Str("username", string(cl.Properties.Username)).
Str("topic", topic).
Msg("client failed allowed ACL check")
return false
}

View File

@ -1,213 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"os"
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
// func teardown(t *testing.T, path string, h *Hook) {
// h.Stop()
// }
func TestBasicID(t *testing.T) {
h := new(Hook)
require.Equal(t, "auth-ledger", h.ID())
}
func TestBasicProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnACLCheck))
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
require.False(t, h.Provides(mqtt.OnPublish))
}
func TestBasicInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestBasicInitDefaultConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
}
func TestBasicInitWithLedgerPointer(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := &Ledger{
Auth: []AuthRule{
{
Remote: "127.0.0.1",
Allow: true,
},
},
ACL: []ACLRule{
{
Remote: "127.0.0.1",
Filters: Filters{
"#": ReadWrite,
},
},
},
}
err := h.Init(&Options{
Ledger: ln,
})
require.NoError(t, err)
require.Same(t, ln, h.ledger)
}
func TestBasicInitWithLedgerJSON(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: ledgerJSON,
})
require.NoError(t, err)
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
}
func TestBasicInitWithLedgerYAML(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: ledgerYAML,
})
require.NoError(t, err)
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
}
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: []byte("fdsfdsafasd"),
})
require.Error(t, err)
}
func TestOnConnectAuthenticate(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
ln.ACL = checkLedger.ACL
err := h.Init(
&Options{
Ledger: ln,
},
)
require.NoError(t, err)
require.True(t, h.OnConnectAuthenticate(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
))
require.False(t, h.OnConnectAuthenticate(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
))
require.False(t, h.OnConnectAuthenticate(
&mqtt.Client{},
packets.Packet{},
))
}
func TestOnACL(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
ln.ACL = checkLedger.ACL
err := h.Init(
&Options{
Ledger: ln,
},
)
require.NoError(t, err)
require.True(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"mochi/info",
true,
))
require.False(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"d/j/f",
true,
))
require.True(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"readonly",
false,
))
require.False(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"readonly",
true,
))
}

View File

@ -1,231 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"encoding/json"
"strings"
"sync"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"gopkg.in/yaml.v3"
)
const (
Deny Access = iota // user cannot access the topic
ReadOnly // user can only subscribe to the topic
WriteOnly // user can only publish to the topic
ReadWrite // user can both publish and subscribe to the topic
)
// Access determines the read/write privileges for an ACL rule.
type Access byte
// Users contains a map of access rules for specific users, keyed on username.
type Users map[string]UserRule
// UserRule defines a set of access rules for a specific user.
type UserRule struct {
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired
Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user
}
// AuthRules defines generic access rules applicable to all users.
type AuthRules []AuthRule
type AuthRule struct {
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users
}
// ACLRules defines generic topic or filter access rules applicable to all users.
type ACLRules []ACLRule
// ACLRule defines access rules for a specific topic or filter.
type ACLRule struct {
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match
}
// Filters is a map of Access rules keyed on filter.
type Filters map[RString]Access
// RString is a rule value string.
type RString string
// Matches returns true if the rule matches a given string.
func (r RString) Matches(a string) bool {
rr := string(r)
if r == "" || r == "*" || a == rr {
return true
}
i := strings.Index(rr, "*")
if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 {
return true
}
return false
}
// FilterMatches returns true if a filter matches a topic rule.
func (f RString) FilterMatches(a string) bool {
_, ok := MatchTopic(string(f), a)
return ok
}
// MatchTopic checks if a given topic matches a filter, accounting for filter
// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c.
func MatchTopic(filter string, topic string) (elements []string, matched bool) {
filterParts := strings.Split(filter, "/")
topicParts := strings.Split(topic, "/")
elements = make([]string, 0)
for i := 0; i < len(filterParts); i++ {
if i >= len(topicParts) {
matched = false
return
}
if filterParts[i] == "+" {
elements = append(elements, topicParts[i])
continue
}
if filterParts[i] == "#" {
matched = true
elements = append(elements, strings.Join(topicParts[i:], "/"))
return
}
if filterParts[i] != topicParts[i] {
matched = false
return
}
}
return elements, true
}
// Ledger is an auth ledger containing access rules for users and topics.
type Ledger struct {
sync.Mutex `json:"-" yaml:"-"`
Users Users `json:"users" yaml:"users"`
Auth AuthRules `json:"auth" yaml:"auth"`
ACL ACLRules `json:"acl" yaml:"acl"`
}
// Update updates the internal values of the ledger.
func (l *Ledger) Update(ln *Ledger) {
l.Lock()
defer l.Unlock()
l.Auth = ln.Auth
l.ACL = ln.ACL
}
// AuthOk returns true if the rules indicate the user is allowed to authenticate.
func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
if l.Users != nil {
if u, ok := l.Users[string(cl.Properties.Username)]; ok &&
u.Password != "" &&
u.Password == RString(pk.Connect.Password) {
return 0, !u.Disallow
}
}
// If there's no users map, or no user was found, attempt to find a matching
// rule (which may also contain a user).
for n, rule := range l.Auth {
if rule.Client.Matches(cl.ID) &&
rule.Username.Matches(string(cl.Properties.Username)) &&
rule.Password.Matches(string(pk.Connect.Password)) &&
rule.Remote.Matches(cl.Net.Remote) {
return n, rule.Allow
}
}
return 0, false
}
// ACLOk returns true if the rules indicate the user is allowed to read or write to
// a specific filter or topic respectively, based on the write bool.
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
if l.Users != nil {
if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 {
for filter, access := range u.ACL {
if filter.FilterMatches(topic) {
if !write && (access == ReadOnly || access == ReadWrite) {
return n, true
} else if write && (access == WriteOnly || access == ReadWrite) {
return n, true
} else {
return n, false
}
}
}
}
}
for n, rule := range l.ACL {
if rule.Client.Matches(cl.ID) &&
rule.Username.Matches(string(cl.Properties.Username)) &&
rule.Remote.Matches(cl.Net.Remote) {
if len(rule.Filters) == 0 {
return n, true
}
for filter, access := range rule.Filters {
if filter.FilterMatches(topic) {
if !write && (access == ReadOnly || access == ReadWrite) {
return n, true
} else if write && (access == WriteOnly || access == ReadWrite) {
return n, true
} else {
return n, false
}
}
}
}
}
return 0, true
}
// ToJSON encodes the values into a JSON string.
func (l *Ledger) ToJSON() (data []byte, err error) {
return json.Marshal(l)
}
// ToYAML encodes the values into a YAML string.
func (l *Ledger) ToYAML() (data []byte, err error) {
return yaml.Marshal(l)
}
// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct.
func (l *Ledger) Unmarshal(data []byte) error {
l.Lock()
defer l.Unlock()
if len(data) == 0 {
return nil
}
if data[0] == '{' {
return json.Unmarshal(data, l)
}
return yaml.Unmarshal(data, &l)
}

View File

@ -1,610 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
var (
checkLedger = Ledger{
Users: Users{ // users are allowed by default
"mochi-co": {
Password: "melon",
ACL: Filters{
"d/+/f": Deny,
"mochi-co/#": ReadWrite,
"readonly": ReadOnly,
},
},
"suspended-username": {
Password: "any",
Disallow: true,
},
"mochi": { // ACL only, will defer to AuthRules for authentication
ACL: Filters{
"special/mochi": ReadOnly,
"secret/mochi": Deny,
"ignored": ReadWrite,
},
},
},
Auth: AuthRules{
{Username: "banned-user"}, // never allow specific username
{Remote: "127.0.0.1", Allow: true}, // always allow localhost
{Remote: "123.123.123.123"}, // disallow any from specific address
{Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address
{Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username)
{Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass
{Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map)
},
ACL: ACLRules{
{
Username: "mochi", // allow matching user/pass
Filters: Filters{
"a/b/c": Deny,
"d/+/f": Deny,
"mochi/#": ReadWrite,
"updates/#": WriteOnly,
"readonly": ReadOnly,
"ignored": Deny,
},
},
{Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost
{Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin
{Remote: "001.002.003.004"}, // Allow all with no filter
{Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others
},
}
)
func TestRStringMatches(t *testing.T) {
require.True(t, RString("*").Matches("any"))
require.True(t, RString("*").Matches(""))
require.True(t, RString("").Matches("any"))
require.True(t, RString("").Matches(""))
require.False(t, RString("no").Matches("any"))
require.False(t, RString("no").Matches(""))
}
func TestCanAuthenticate(t *testing.T) {
tt := []struct {
desc string
client *mqtt.Client
pk packets.Packet
n int
ok bool
}{
{
desc: "allow all local 127.0.0.1",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{Connect: packets.ConnectParams{}},
ok: true,
n: 1,
},
{
desc: "allow username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 5,
},
{
desc: "deny username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: false,
n: 0,
},
{
desc: "allow all local 127.0.0.1",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: true,
n: 1,
},
{
desc: "allow username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 5,
},
{
desc: "deny username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: false,
n: 0,
},
{
desc: "deny client from address",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("not-mochi"),
},
Net: mqtt.ClientConnection{
Remote: "111.144.155.166",
},
},
pk: packets.Packet{},
ok: false,
n: 3,
},
{
desc: "allow remote wildcard",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "111.0.0.1",
},
},
pk: packets.Packet{},
ok: true,
n: 4,
},
{
desc: "never allow username",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("banned-user"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{},
ok: false,
n: 0,
},
{
desc: "matching user in users",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi-co"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 0,
},
{
desc: "never user in users",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("suspended-user"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}},
ok: false,
n: 0,
},
}
for _, d := range tt {
t.Run(d.desc, func(t *testing.T) {
n, ok := checkLedger.AuthOk(d.client, d.pk)
require.Equal(t, d.n, n)
require.Equal(t, d.ok, ok)
})
}
}
func TestCanACL(t *testing.T) {
tt := []struct {
client *mqtt.Client
desc string
topic string
n int
write bool
ok bool
}{
{
desc: "allow normal write on any other filter",
client: &mqtt.Client{},
topic: "default/acl/write/access",
write: true,
ok: true,
},
{
desc: "allow normal read on any other filter",
client: &mqtt.Client{},
topic: "default/acl/read/access",
write: false,
ok: true,
},
{
desc: "deny user on literal filter",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "a/b/c",
},
{
desc: "deny user on partial filter",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "d/j/f",
},
{
desc: "allow read/write to user path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "mochi/read/write",
write: true,
ok: true,
},
{
desc: "deny read on write-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/no/reading",
write: false,
ok: false,
},
{
desc: "deny read on write-only path ext",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/mochi",
write: false,
ok: false,
},
{
desc: "allow read on not-acl path (no #)",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates",
write: false,
ok: true,
},
{
desc: "allow write on write-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/mochi",
write: true,
ok: true,
},
{
desc: "deny write on read-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "readonly",
write: true,
ok: false,
},
{
desc: "allow read on read-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "readonly",
write: false,
ok: true,
},
{
desc: "allow $sys access to localhost",
client: &mqtt.Client{
Net: mqtt.ClientConnection{
Remote: "localhost",
},
},
topic: "$SYS/test",
write: false,
ok: true,
n: 1,
},
{
desc: "allow $sys access to admin",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("admin"),
},
},
topic: "$SYS/test",
write: false,
ok: true,
n: 2,
},
{
desc: "deny $sys access to all others",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "$SYS/test",
write: false,
ok: false,
n: 4,
},
{
desc: "allow all with no filter",
client: &mqtt.Client{
Net: mqtt.ClientConnection{
Remote: "001.002.003.004",
},
},
topic: "any/path",
write: true,
ok: true,
n: 3,
},
{
desc: "use users embedded acl deny",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "secret/mochi",
write: true,
ok: false,
},
{
desc: "use users embedded acl any",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "any/mochi",
write: true,
ok: true,
},
{
desc: "use users embedded acl write on read-only",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "special/mochi",
write: true,
ok: false,
},
{
desc: "use users embedded acl read on read-only",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "special/mochi",
write: false,
ok: true,
},
{
desc: "preference users embedded acl",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "ignored",
write: true,
ok: true,
},
}
for _, d := range tt {
t.Run(d.desc, func(t *testing.T) {
n, ok := checkLedger.ACLOk(d.client, d.topic, d.write)
require.Equal(t, d.n, n)
require.Equal(t, d.ok, ok)
})
}
}
func TestMatchTopic(t *testing.T) {
el, matched := MatchTopic("a/+/c/+", "a/b/c/d")
require.True(t, matched)
require.Equal(t, []string{"b", "d"}, el)
el, matched = MatchTopic("a/+/+/+", "a/b/c/d")
require.True(t, matched)
require.Equal(t, []string{"b", "c", "d"}, el)
el, matched = MatchTopic("stuff/#", "stuff/things/yeah")
require.True(t, matched)
require.Equal(t, []string{"things/yeah"}, el)
el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds")
require.True(t, matched)
require.Equal(t, []string{"b", "c/d/as/dds"}, el)
el, matched = MatchTopic("test", "test")
require.True(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic("things/stuff//", "things/stuff/")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic("t", "t2")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic(" ", " ")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
}
var (
ledgerStruct = Ledger{
Users: Users{
"mochi": {
Password: "peach",
ACL: Filters{
"readonly": ReadOnly,
"deny": Deny,
},
},
},
Auth: AuthRules{
{
Client: "*",
Username: "mochi-co",
Password: "melon",
Remote: "192.168.1.*",
Allow: true,
},
},
ACL: ACLRules{
{
Client: "*",
Username: "mochi-co",
Remote: "127.*",
Filters: Filters{
"readonly": ReadOnly,
"writeonly": WriteOnly,
"readwrite": ReadWrite,
"deny": Deny,
},
},
},
}
ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`)
ledgerYAML = []byte(`users:
mochi:
password: peach
acl:
deny: 0
readonly: 1
auth:
- client: '*'
username: mochi-co
remote: 192.168.1.*
password: melon
allow: true
acl:
- client: '*'
username: mochi-co
remote: 127.*
filters:
deny: 0
readonly: 1
readwrite: 3
writeonly: 2
`)
)
func TestLedgerUpdate(t *testing.T) {
old := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
},
}
new := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
{Remote: "192.168.*", Allow: true},
},
}
old.Update(new)
require.Len(t, old.Auth, 2)
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
require.NotSame(t, new, old)
}
func TestLedgerToJSON(t *testing.T) {
data, err := ledgerStruct.ToJSON()
require.NoError(t, err)
require.Equal(t, ledgerJSON, data)
}
func TestLedgerToYAML(t *testing.T) {
data, err := ledgerStruct.ToYAML()
require.NoError(t, err)
require.Equal(t, ledgerYAML, data)
}
func TestLedgerUnmarshalFromYAML(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal(ledgerYAML)
require.NoError(t, err)
require.Equal(t, &ledgerStruct, l)
require.NotSame(t, l, &ledgerStruct)
}
func TestLedgerUnmarshalFromJSON(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal(ledgerJSON)
require.NoError(t, err)
require.Equal(t, &ledgerStruct, l)
require.NotSame(t, l, &ledgerStruct)
}
func TestLedgerUnmarshalNil(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal([]byte{})
require.NoError(t, err)
require.Equal(t, new(Ledger), l)
}

View File

@ -1,250 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package debug
import (
"strings"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/rs/zerolog"
)
// Options contains configuration settings for the debug output.
type Options struct {
ShowPacketData bool // include decoded packet data (default false)
ShowPings bool // show ping requests and responses (default false)
ShowPasswords bool // show connecting user passwords (default false)
}
// Hook is a debugging hook which logs additional low-level information from the server.
type Hook struct {
mqtt.HookBase
config *Options
Log *zerolog.Logger
}
// ID returns the ID of the hook.
func (h *Hook) ID() string {
return "debug"
}
// Provides indicates that this hook provides all methods.
func (h *Hook) Provides(b byte) bool {
return true
}
// Init is called when the hook is initialized.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
return nil
}
// SetOpts is called when the hook receives inheritable server parameters.
func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) {
h.Log = l
h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send()
}
// Stop is called when the hook is stopped.
func (h *Hook) Stop() error {
h.Log.Debug().Str("method", "Stop").Send()
return nil
}
// OnStarted is called when the server starts.
func (h *Hook) OnStarted() {
h.Log.Debug().Str("method", "OnStarted").Send()
}
// OnStopped is called when the server stops.
func (h *Hook) OnStopped() {
h.Log.Debug().Str("method", "OnStopped").Send()
}
// OnPacketRead is called when a new packet is received from a client.
func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
return pk, nil
}
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
return pk, nil
}
// OnPacketSent is called when a packet is sent to a client.
func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
return
}
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
}
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic")
}
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out")
}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete")
}
// OnQosDropped is called the Qos flow for a message expires.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped")
}
// OnLWTSent is called when a will message has been issued from a disconnecting client.
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client")
}
// OnRetainedExpired is called when the server clears expired retained messages.
func (h *Hook) OnRetainedExpired(filter string) {
h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired")
}
// OnClientExpired is called when the server clears an expired client.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired")
}
// StoredClients is called when the server restores clients from a store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
h.Log.Debug().
Str("method", "StoredClients").
Send()
return v, nil
}
// StoredClients is called when the server restores subscriptions from a store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
h.Log.Debug().
Str("method", "StoredSubscriptions").
Send()
return v, nil
}
// StoredClients is called when the server restores retained messages from a store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
h.Log.Debug().
Str("method", "StoredRetainedMessages").
Send()
return v, nil
}
// StoredClients is called when the server restores inflight messages from a store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
h.Log.Debug().
Str("method", "StoredInflightMessages").
Send()
return v, nil
}
// StoredClients is called when the server restores system info from a store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
h.Log.Debug().
Str("method", "StoredClients").
Send()
return v, nil
}
// packetMeta adds additional type-specific metadata to the debug logs.
func (h *Hook) packetMeta(pk packets.Packet) map[string]any {
m := map[string]any{}
switch pk.FixedHeader.Type {
case packets.Connect:
m["id"] = pk.Connect.ClientIdentifier
m["clean"] = pk.Connect.Clean
m["keepalive"] = pk.Connect.Keepalive
m["version"] = pk.ProtocolVersion
m["username"] = string(pk.Connect.Username)
if h.config.ShowPasswords {
m["password"] = string(pk.Connect.Password)
}
if pk.Connect.WillFlag {
m["will_topic"] = pk.Connect.WillTopic
m["will_payload"] = string(pk.Connect.WillPayload)
}
case packets.Publish:
m["topic"] = pk.TopicName
m["payload"] = string(pk.Payload)
m["raw"] = pk.Payload
m["qos"] = pk.FixedHeader.Qos
m["id"] = pk.PacketID
case packets.Connack:
fallthrough
case packets.Disconnect:
fallthrough
case packets.Puback:
fallthrough
case packets.Pubrec:
fallthrough
case packets.Pubrel:
fallthrough
case packets.Pubcomp:
m["id"] = pk.PacketID
m["reason"] = int(pk.ReasonCode)
if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 {
m["reason_string"] = pk.Properties.ReasonString
}
case packets.Subscribe:
f := map[string]int{}
ids := map[string]int{}
for _, v := range pk.Filters {
f[v.Filter] = int(v.Qos)
ids[v.Filter] = v.Identifier
}
m["filters"] = f
m["subids"] = f
case packets.Unsubscribe:
f := []string{}
for _, v := range pk.Filters {
f = append(f, v.Filter)
}
m["filters"] = f
case packets.Suback:
fallthrough
case packets.Unsuback:
r := []int{}
for _, v := range pk.ReasonCodes {
r = append(r, int(v))
}
m["reasons"] = r
case packets.Auth:
// tbd
}
if h.config.ShowPacketData {
m["packet"] = pk
}
return m
}

View File

@ -1,473 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"bytes"
"errors"
"strings"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/timshannon/badgerhold"
)
const (
// defaultDbFile is the default file path for the badger db file.
defaultDbFile = ".badger"
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the BadgerDB instance.
type Options struct {
Options *badgerhold.Options
Path string
}
// Hook is a persistent storage hook based using BadgerDB file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the BadgerDB instance.
db *badgerhold.Store // the BadgerDB instance.
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "badger-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// Init initializes and connects to the badger instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
if h.config.Path == "" {
h.config.Path = defaultDbFile
}
options := badgerhold.DefaultOptions
options.Dir = h.config.Path
options.ValueDir = h.config.Path
options.Logger = h
var err error
h.db, err = badgerhold.Open(options)
if err != nil {
return err
}
return nil
}
// Stop closes the badger instance.
func (h *Hook) Stop() error {
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data")
}
}
// OnDisconnect removes a client from the store if their session has expired.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
h.updateClient(cl)
if !expire {
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription))
if err != nil {
h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
PacketID: pk.PacketID,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(inflightKey(cl, pk), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data")
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(retainedKey(filter), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.ClientKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.SubscriptionKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.RetainedKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Get(storage.SysInfoKey, &v)
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// Errorf satisfies the badger interface for an error logger.
func (h *Hook) Errorf(m string, v ...interface{}) {
h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Warningf satisfies the badger interface for a warning logger.
func (h *Hook) Warningf(m string, v ...interface{}) {
h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Infof satisfies the badger interface for an info logger.
func (h *Hook) Infof(m string, v ...interface{}) {
h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Debugf satisfies the badger interface for a debug logger.
func (h *Hook) Debugf(m string, v ...interface{}) {
h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}

View File

@ -1,681 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"os"
"strings"
"testing"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"github.com/timshannon/badgerhold"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
h.db.Badger().Close()
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
require.NoError(t, err)
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "badger-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.db.Get(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.db.Get(clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.db.Get(clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, badgerhold.ErrNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.db.Upsert(clientKey, &storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.db.Get(clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.db.Get(clientKey, r)
require.Error(t, err)
require.ErrorIs(t, badgerhold.ErrNotFound, err)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.db.Get(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, badgerhold.ErrNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.db.Upsert(m.ID, m)
require.NoError(t, err)
r := new(storage.Message)
err = h.db.Get(m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.db.Get(m.ID, r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.db.Get(inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.db.Get(inflightKey(client, pk), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.db.Get(storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.db.Upsert("cl1", &storage.Client{ID: "cl1", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Upsert("cl2", &storage.Client{ID: "cl2", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Upsert("cl3", &storage.Client{ID: "cl3", T: storage.ClientKey})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.db.Upsert("sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Upsert("sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Upsert("sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("m2", &storage.Message{ID: "m2", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("m3", &storage.Message{ID: "m3", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert(storage.SysInfoKey, &storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestErrorf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Errorf("test", 1, 2, 3)
}
func TestWarningf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Warningf("test", 1, 2, 3)
}
func TestInfof(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Infof("test", 1, 2, 3)
}
func TestDebugf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Debugf("test", 1, 2, 3)
}

View File

@ -1,474 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
package bolt
import (
"bytes"
"errors"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
sgob "github.com/asdine/storm/codec/gob"
"github.com/asdine/storm/v3"
"go.etcd.io/bbolt"
)
const (
// defaultDbFile is the default file path for the boltdb file.
defaultDbFile = "bolt.db"
// defaultTimeout is the default time to hold a connection to the file.
defaultTimeout = 250 * time.Millisecond
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the bolt instance.
type Options struct {
Options *bbolt.Options
Path string
}
// Hook is a persistent storage hook based using boltdb file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the boltdb instance.
db *storm.DB // the boltdb instance.
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "bolt-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// Init initializes and connects to the boltdb instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
if h.config.Options == nil {
h.config.Options = &bbolt.Options{
Timeout: defaultTimeout,
}
}
if h.config.Path == "" {
h.config.Path = defaultDbFile
}
var err error
h.db, err = storm.Open(h.config.Path, storm.BoltOptions(0600, h.config.Options), storm.Codec(sgob.Codec))
if err != nil {
return err
}
return nil
}
// Stop closes the boltdb instance.
func (h *Hook) Stop() error {
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data")
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if !expire {
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.DeleteStruct(&storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", subscriptionKey(cl, pk.Filters[i].Filter)).
Msg("failed to delete client")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.DeleteStruct(&storage.Message{
ID: retainedKey(pk.TopicName),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", retainedKey(pk.TopicName)).
Msg("failed to delete retained publish")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save retained publish data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save qos inflight data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Message{
ID: inflightKey(cl, pk),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", inflightKey(cl, pk)).
Msg("failed to delete inflight data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Interface("data", in).
Msg("failed to save $SYS data")
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.ClientKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.SubscriptionKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.RetainedKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.One("ID", storage.SysInfoKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}

View File

@ -1,717 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package bolt
import (
"os"
"testing"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/asdine/storm/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
err := os.Remove(path)
require.NoError(t, err)
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "bolt-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultTimeout, h.config.Options.Timeout)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestInitBadPath(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Path: "..",
})
require.Error(t, err)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.db.One("ID", clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.db.One("ID", clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.db.One("ID", clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, storm.ErrNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.db.One("ID", clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.db.Save(&storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.db.One("ID", clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.db.One("ID", clientKey, r)
require.Error(t, err)
require.ErrorIs(t, storm.ErrNotFound, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.db.Save(m)
require.NoError(t, err)
r := new(storage.Message)
err = h.db.One("ID", m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.db.One("ID", m.ID, r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.db.One("ID", inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved to bolt
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.db.One("ID", inflightKey(client, pk), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.db.One("ID", storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.db.Save(&storage.Client{ID: "cl1", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Save(&storage.Client{ID: "cl2", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Save(&storage.Client{ID: "cl3", T: storage.ClientKey})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredClientsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredClients()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.db.Save(&storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Save(&storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Save(&storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptionsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m2", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m3", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with sys info
err = h.db.Save(&storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfoClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.Error(t, err)
}

View File

@ -1,513 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
redis "github.com/go-redis/redis/v8"
)
// defaultAddr is the default address to the redis service.
const defaultAddr = "localhost:6379"
// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt.
const defaultHPrefix = "mochi-"
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the bolt instance.
type Options struct {
HPrefix string
Options *redis.Options
}
// Hook is a persistent storage hook based using Redis as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for connecting to the Redis instance.
db *redis.Client // the Redis instance
ctx context.Context // a context for the connection
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "redis-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnWillSent,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// hKey returns a hash set key with a unique prefix.
func (h *Hook) hKey(s string) string {
return h.config.HPrefix + s
}
// Init initializes and connects to the redis service.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
h.ctx = context.Background()
if config == nil {
config = &Options{
Options: &redis.Options{
Addr: defaultAddr,
},
}
}
h.config = config.(*Options)
if h.config.HPrefix == "" {
h.config.HPrefix = defaultHPrefix
}
h.Log.Info().
Str("address", h.config.Options.Addr).
Str("username", h.config.Options.Username).
Int("password-len", len(h.config.Options.Password)).
Int("db", h.config.Options.DB).
Msg("connecting to redis service")
h.db = redis.NewClient(h.config.Options)
_, err := h.db.Ping(context.Background()).Result()
if err != nil {
return fmt.Errorf("failed to ping service: %w", err)
}
h.Log.Info().Msg("connected to redis service")
return nil
}
// Close closes the redis connection.
func (h *Hook) Stop() error {
h.Log.Info().Msg("disconnecting from redis service")
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data")
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if !expire {
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data")
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll client data")
return
}
for _, row := range rows {
var d storage.Client
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data")
}
v = append(v, d)
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll subscription data")
return
}
for _, row := range rows {
var d storage.Subscription
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data")
}
v = append(v, d)
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll retained message data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data")
}
v = append(v, d)
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
v = append(v, d)
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
if err != nil && !errors.Is(err, redis.Nil) {
return
}
if err = v.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data")
}
return v, nil
}

View File

@ -1,789 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"os"
"sort"
"testing"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
miniredis "github.com/alicebob/miniredis/v2"
redis "github.com/go-redis/redis/v8"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func newHook(t *testing.T, addr string) *Hook {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: addr,
},
})
require.NoError(t, err)
return h
}
func teardown(t *testing.T, h *Hook) {
if h.db != nil {
err := h.db.FlushAll(h.ctx).Err()
require.NoError(t, err)
h.Stop()
}
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, "cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, "a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, "cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Equal(t, "redis-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestHKey(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.SetOpts(&logger, nil)
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
}
func TestInitUseDefaults(t *testing.T) {
s := miniredis.RunT(t)
s.StartAddr(defaultAddr)
defer s.Close()
h := newHook(t, defaultAddr)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h)
require.Equal(t, defaultHPrefix, h.config.HPrefix)
require.Equal(t, defaultAddr, h.config.Options.Addr)
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitBadAddr(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: "abc:123",
},
})
require.Error(t, err)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r2.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
require.Empty(t, r3.ID)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnClientExpired(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err()
require.NoError(t, err)
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, clientKey, r.ID)
h.OnClientExpired(cl)
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
require.Error(t, err)
require.ErrorIs(t, redis.Nil, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnSubscribedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpired(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err()
require.NoError(t, err)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved to bolt
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnQosPublishNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with clients
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err()
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredClientsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredClients()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with subscriptions
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptionsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with messages
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with messages
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessagesClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSysInfo(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with sys info
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey,
&storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
}).Err()
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfoClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.Error(t, err)
}

View File

@ -1,164 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"encoding/json"
"errors"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
)
const (
SubscriptionKey = "SUB" // unique key to denote Subscriptions in a store
SysInfoKey = "SYS" // unique key to denote server system information in a store
RetainedKey = "RET" // unique key to denote retained messages in a store
InflightKey = "IFM" // unique key to denote inflight messages in a store
ClientKey = "CL" // unique key to denote clients in a store
)
var (
// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading.
ErrDBFileNotOpen = errors.New("db file not open")
)
// Client is a storable representation of an mqtt client.
type Client struct {
Will ClientWill `json:"will"` // will topic and payload data if applicable
Properties ClientProperties `json:"properties"` // the connect properties for the client
Username []byte `json:"username"` // the username of the client
ID string `json:"id" storm:"id"` // the client id / storage key
T string `json:"t"` // the data type (client)
Remote string `json:"remote"` // the remote address of the client
Listener string `json:"listener"` // the listener the client connected on
ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client
Clean bool `json:"clean"` // if the client requested a clean start/session
}
// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection.
type ClientProperties struct {
AuthenticationData []byte `json:"authenticationData"`
User []packets.UserProperty `json:"user"`
AuthenticationMethod string `json:"authenticationMethod"`
SessionExpiryInterval uint32 `json:"sessionExpiryInterval"`
MaximumPacketSize uint32 `json:"maximumPacketSize"`
ReceiveMaximum uint16 `json:"receiveMaximum"`
TopicAliasMaximum uint16 `json:"topicAliasMaximum"`
SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag"`
RequestProblemInfo byte `json:"requestProblemInfo"`
RequestProblemInfoFlag bool `json:"requestProblemInfoFlag"`
RequestResponseInfo byte `json:"requestResponseInfo"`
}
// ClientWill contains a will message for a client, and limited mqtt v5 properties.
type ClientWill struct {
Payload []byte `json:"payload"`
User []packets.UserProperty `json:"user"`
TopicName string `json:"topicName"`
Flag uint32 `json:"flag"`
WillDelayInterval uint32 `json:"willDelayInterval"`
Qos byte `json:"qos"`
Retain bool `json:"retain"`
}
// MarshalBinary encodes the values into a json string.
func (d Client) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Client) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// Message is a storable representation of an MQTT message (specifically publish).
type Message struct {
Properties MessageProperties `json:"properties"` // -
Payload []byte `json:"payload"` // the message payload (if retained)
T string `json:"t"` // the data type
ID string `json:"id" storm:"id"` // the storage key
Origin string `json:"origin"` // the id of the client who sent the message
TopicName string `json:"topic_name"` // the topic the message was sent to (if retained)
FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message
Created int64 `json:"created"` // the time the message was created in unixtime
Sent int64 `json:"sent"` // the last time the message was sent (for retries) in unixtime (if inflight)
PacketID uint16 `json:"packet_id"` // the unique id of the packet (if inflight)
}
// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages.
type MessageProperties struct {
CorrelationData []byte `json:"correlationData"`
SubscriptionIdentifier []int `json:"subscriptionIdentifier"`
User []packets.UserProperty `json:"user"`
ContentType string `json:"contentType"`
ResponseTopic string `json:"responseTopic"`
MessageExpiryInterval uint32 `json:"messageExpiry"`
TopicAlias uint16 `json:"topicAlias"`
PayloadFormat byte `json:"payloadFormat"`
PayloadFormatFlag bool `json:"payloadFormatFlag"`
}
// MarshalBinary encodes the values into a json string.
func (d Message) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Message) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// Subscription is a storable representation of an mqtt subscription.
type Subscription struct {
T string `json:"t"`
ID string `json:"id" storm:"id"`
Client string `json:"client"`
Filter string `json:"filter"`
Identifier int `json:"identifier"`
RetainHandling byte `json:"retain_handling"`
Qos byte `json:"qos"`
RetainAsPublished bool `json:"retain_as_pub"`
NoLocal bool `json:"no_local"`
}
// MarshalBinary encodes the values into a json string.
func (d Subscription) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Subscription) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// SystemInfo is a storable representation of the system information values.
type SystemInfo struct {
system.Info // embed the system info struct
T string `json:"t"` // the data type
ID string `json:"id" storm:"id"` // the storage key
}
// MarshalBinary encodes the values into a json string.
func (d SystemInfo) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *SystemInfo) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}

View File

@ -1,196 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"testing"
"time"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
var (
clientStruct = Client{
ID: "test",
T: "client",
Remote: "remote",
Listener: "listener",
Username: []byte("mochi"),
Clean: true,
Properties: ClientProperties{
SessionExpiryInterval: 2,
SessionExpiryIntervalFlag: true,
AuthenticationMethod: "a",
AuthenticationData: []byte("test"),
RequestProblemInfo: 1,
RequestProblemInfoFlag: true,
RequestResponseInfo: 1,
ReceiveMaximum: 128,
TopicAliasMaximum: 256,
User: []packets.UserProperty{
{Key: "k", Val: "v"},
},
MaximumPacketSize: 120,
},
Will: ClientWill{
Qos: 1,
Payload: []byte("abc"),
TopicName: "a/b/c",
Flag: 1,
Retain: true,
WillDelayInterval: 2,
User: []packets.UserProperty{
{Key: "k2", Val: "v2"},
},
},
}
clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)
messageStruct = Message{
T: "message",
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: 2,
Type: 3,
Qos: 1,
Dup: true,
Retain: true,
},
ID: "id",
Origin: "mochi",
TopicName: "topic",
Properties: MessageProperties{
PayloadFormat: 1,
PayloadFormatFlag: true,
MessageExpiryInterval: 20,
ContentType: "type",
ResponseTopic: "a/b/r",
CorrelationData: []byte("r"),
SubscriptionIdentifier: []int{1},
TopicAlias: 2,
User: []packets.UserProperty{
{Key: "k2", Val: "v2"},
},
},
Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
PacketID: 100,
}
messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)
subscriptionStruct = Subscription{
T: "subscription",
ID: "id",
Client: "mochi",
Filter: "a/b/c",
Qos: 1,
}
subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","identifier":0,"retain_handling":0,"qos":1,"retain_as_pub":false,"no_local":false}`)
sysInfoStruct = SystemInfo{
T: "info",
ID: "id",
Info: system.Info{
Version: "2.0.0",
Started: 1,
Uptime: 2,
BytesReceived: 3,
BytesSent: 4,
ClientsConnected: 5,
ClientsMaximum: 7,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
PacketsReceived: 12,
PacketsSent: 13,
Retained: 15,
Inflight: 16,
InflightDropped: 17,
},
}
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)
func TestClientMarshalBinary(t *testing.T) {
data, err := clientStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, clientJSON, data)
}
func TestClientUnmarshalBinary(t *testing.T) {
d := clientStruct
err := d.UnmarshalBinary(clientJSON)
require.NoError(t, err)
require.Equal(t, clientStruct, d)
}
func TestClientUnmarshalBinaryEmpty(t *testing.T) {
d := Client{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Client{}, d)
}
func TestMessageMarshalBinary(t *testing.T) {
data, err := messageStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, messageJSON, data)
}
func TestMessageUnmarshalBinary(t *testing.T) {
d := messageStruct
err := d.UnmarshalBinary(messageJSON)
require.NoError(t, err)
require.Equal(t, messageStruct, d)
}
func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
d := Message{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Message{}, d)
}
func TestSubscriptionMarshalBinary(t *testing.T) {
data, err := subscriptionStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, subscriptionJSON, data)
}
func TestSubscriptionUnmarshalBinary(t *testing.T) {
d := subscriptionStruct
err := d.UnmarshalBinary(subscriptionJSON)
require.NoError(t, err)
require.Equal(t, subscriptionStruct, d)
}
func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
d := Subscription{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Subscription{}, d)
}
func TestSysInfoMarshalBinary(t *testing.T) {
data, err := sysInfoStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, sysInfoJSON, data)
}
func TestSysInfoUnmarshalBinary(t *testing.T) {
d := sysInfoStruct
err := d.UnmarshalBinary(sysInfoJSON)
require.NoError(t, err)
require.Equal(t, sysInfoStruct, d)
}
func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
d := SystemInfo{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, SystemInfo{}, d)
}

View File

@ -1,634 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"errors"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
type modifiedHookBase struct {
HookBase
err error
fail bool
failAt int
}
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) Provides(b byte) bool {
return true
}
func (h *modifiedHookBase) Stop() error {
if h.fail {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return true
}
func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
return true
}
func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) {
if h.fail {
return will, errTestHook
}
return will, nil
}
func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) {
if h.fail || h.failAt == 1 {
return v, errTestHook
}
return []storage.Client{
{ID: "cl1"},
{ID: "cl2"},
{ID: "cl3"},
}, nil
}
func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.fail || h.failAt == 2 {
return v, errTestHook
}
return []storage.Subscription{
{ID: "sub1"},
{ID: "sub2"},
{ID: "sub3"},
}, nil
}
func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) {
if h.fail || h.failAt == 3 {
return v, errTestHook
}
return []storage.Message{
{ID: "r1"},
{ID: "r2"},
{ID: "r3"},
}, nil
}
func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) {
if h.fail || h.failAt == 4 {
return v, errTestHook
}
return []storage.Message{
{ID: "i1"},
{ID: "i2"},
{ID: "i3"},
}, nil
}
func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.fail || h.failAt == 5 {
return v, errTestHook
}
return storage.SystemInfo{
Info: system.Info{
Version: "2.0.0",
},
}, nil
}
type providesCheckHook struct {
HookBase
}
func (h *providesCheckHook) Provides(b byte) bool {
return b == OnConnect
}
func TestHooksProvides(t *testing.T) {
h := new(Hooks)
err := h.Add(new(providesCheckHook), nil)
require.NoError(t, err)
err = h.Add(new(HookBase), nil)
require.NoError(t, err)
require.True(t, h.Provides(OnConnect, OnDisconnect))
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), map[string]any{})
require.Error(t, err)
require.Equal(t, int64(0), atomic.LoadInt64(&h.qty))
}
func TestHooksStop(t *testing.T) {
h := new(Hooks)
h.Log = &logger
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
h.Stop()
}
// coverage: also cover some empty functions
func TestHooksNonReturns(t *testing.T) {
h := new(Hooks)
cl := new(Client)
for i := 0; i < 2; i++ {
t.Run("step-"+strconv.Itoa(i), func(t *testing.T) {
// on first iteration, check without hook methods
h.OnStarted()
h.OnStopped()
h.OnSysInfoTick(new(system.Info))
h.OnConnect(cl, packets.Packet{})
h.OnSessionEstablished(cl, packets.Packet{})
h.OnDisconnect(cl, nil, false)
h.OnPacketSent(cl, packets.Packet{}, []byte{})
h.OnPacketProcessed(cl, packets.Packet{}, nil)
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnPublishDropped(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})
h.OnQosDropped(cl, packets.Packet{})
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
})
}
}
func TestHooksOnConnectAuthenticate(t *testing.T) {
h := new(Hooks)
ok := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, ok)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
ok = h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.True(t, ok)
}
func TestHooksOnACLCheck(t *testing.T) {
h := new(Hooks)
ok := h.OnACLCheck(new(Client), "a/b/c", true)
require.False(t, ok)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
ok = h.OnACLCheck(new(Client), "a/b/c", true)
require.True(t, ok)
}
func TestHooksOnSubscribe(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
pki := packets.Packet{
Filters: packets.Subscriptions{
{Filter: "a/b/c", Qos: 1},
},
}
pk := h.OnSubscribe(new(Client), pki)
require.EqualValues(t, pk, pki)
}
func TestHooksOnSelectSubscribers(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
subs := &Subscribers{
Subscriptions: map[string]packets.Subscription{
"cl1": {Filter: "a/b/c"},
},
}
subs2 := h.OnSelectSubscribers(subs, packets.Packet{})
require.EqualValues(t, subs, subs2)
}
func TestHooksOnUnsubscribe(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
pki := packets.Packet{
Filters: packets.Subscriptions{
{Filter: "a/b/c", Qos: 1},
},
}
pk := h.OnUnsubscribe(new(Client), pki)
require.EqualValues(t, pk, pki)
}
func TestHooksOnPublish(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: failure
hook.fail = true
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
hook.err = packets.ErrRejectPacket
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrRejectPacket)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnPacketRead(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: failure
hook.fail = true
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
hook.err = packets.ErrRejectPacket
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrRejectPacket)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnAuthPacket(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
hook.fail = true
pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnPacketEncode(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10})
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnLWT(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.Equal(t, "a/b/c", lwt.TopicName)
// coverage: fail lwt
hook.fail = true
lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.Equal(t, "a/b/c", lwt.TopicName)
}
func TestHooksStoredClients(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredClients()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredClients()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredSubscriptions(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredSubscriptions()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredRetainedMessages(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredRetainedMessages()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredInflightMessages(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredInflightMessages()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredSysInfo(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "", v.Info.Version)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", v.Info.Version)
hook.fail = true
v, err = h.StoredSysInfo()
require.Error(t, err)
require.Equal(t, "", v.Info.Version)
}
func TestHookBaseID(t *testing.T) {
h := new(HookBase)
require.Equal(t, "base", h.ID())
}
func TestHookBaseProvidesNone(t *testing.T) {
h := new(HookBase)
require.False(t, h.Provides(OnConnect))
require.False(t, h.Provides(OnDisconnect))
}
func TestHookBaseInit(t *testing.T) {
h := new(HookBase)
require.Nil(t, h.Init(nil))
}
func TestHookBaseSetOpts(t *testing.T) {
h := new(HookBase)
h.SetOpts(&logger, new(HookOptions))
require.NotNil(t, h.Log)
require.NotNil(t, h.Opts)
}
func TestHookBaseClose(t *testing.T) {
h := new(HookBase)
require.Nil(t, h.Stop())
}
func TestHookBaseOnConnectAuthenticate(t *testing.T) {
h := new(HookBase)
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, v)
}
func TestHookBaseOnACLCheck(t *testing.T) {
h := new(HookBase)
v := h.OnACLCheck(new(Client), "topic", true)
require.False(t, v)
}
func TestHookBaseOnPublish(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnPacketRead(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnAuthPacket(t *testing.T) {
h := new(HookBase)
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnLWT(t *testing.T) {
h := new(HookBase)
lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.NoError(t, err)
require.Equal(t, "a/b/c", lwt.TopicName)
}
func TestHookBaseStoredClients(t *testing.T) {
h := new(HookBase)
v, err := h.StoredClients()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredSubscriptions(t *testing.T) {
h := new(HookBase)
v, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredInflightMessages(t *testing.T) {
h := new(HookBase)
v, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredRetainedMessages(t *testing.T) {
h := new(HookBase)
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoreSysInfo(t *testing.T) {
h := new(HookBase)
v, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "", v.Version)
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 42 KiB

View File

@ -1,156 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sort"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/v2/packets"
)
// Inflight is a map of InflightMessage keyed on packet id.
type Inflight struct {
sync.RWMutex
internal map[uint16]packets.Packet // internal contains the inflight packets
receiveQuota int32 // remaining inbound qos quota for flow control
sendQuota int32 // remaining outbound qos quota for flow control
maximumReceiveQuota int32 // maximum allowed receive quota
maximumSendQuota int32 // maximum allowed send quota
}
// NewInflights returns a new instance of an Inflight packets map.
func NewInflights() *Inflight {
return &Inflight{
internal: map[uint16]packets.Packet{},
}
}
// Set adds or updates an inflight packet by packet id.
func (i *Inflight) Set(m packets.Packet) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[m.PacketID]
i.internal[m.PacketID] = m
return !ok
}
// Get returns an inflight packet by packet id.
func (i *Inflight) Get(id uint16) (packets.Packet, bool) {
i.RLock()
defer i.RUnlock()
if m, ok := i.internal[id]; ok {
return m, true
}
return packets.Packet{}, false
}
// Len returns the size of the inflight messages map.
func (i *Inflight) Len() int {
i.RLock()
defer i.RUnlock()
return len(i.internal)
}
// Clone returns a new instance of Inflight with the same message data.
// This is used when transferring inflights from a taken-over session.
func (i *Inflight) Clone() *Inflight {
c := NewInflights()
i.RLock()
defer i.RUnlock()
for k, v := range i.internal {
c.internal[k] = v
}
return c
}
// GetAll returns all the inflight messages.
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
i.RLock()
defer i.RUnlock()
m := []packets.Packet{}
for _, v := range i.internal {
if !immediate || (immediate && v.Expiry < 0) {
m = append(m, v)
}
}
sort.Slice(m, func(i, j int) bool {
return uint16(m[i].Created) < uint16(m[j].Created)
})
return m
}
// NextImmediate returns the next inflight packet which is indicated to be sent immediately.
// This typically occurs when the quota has been exhausted, and we need to wait until new quota
// is free to continue sending.
func (i *Inflight) NextImmediate() (packets.Packet, bool) {
i.RLock()
defer i.RUnlock()
m := i.GetAll(true)
if len(m) > 0 {
return m[0], true
}
return packets.Packet{}, false
}
// Delete removes an in-flight message from the map. Returns true if the message existed.
func (i *Inflight) Delete(id uint16) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[id]
delete(i.internal, id)
return ok
}
// TakeRecieveQuota reduces the receive quota by 1.
func (i *Inflight) DecreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) > 0 {
atomic.AddInt32(&i.receiveQuota, -1)
}
}
// TakeRecieveQuota increases the receive quota by 1.
func (i *Inflight) IncreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
atomic.AddInt32(&i.receiveQuota, 1)
}
}
// ResetReceiveQuota resets the receive quota to the maximum allowed value.
func (i *Inflight) ResetReceiveQuota(n int32) {
atomic.StoreInt32(&i.receiveQuota, n)
atomic.StoreInt32(&i.maximumReceiveQuota, n)
}
// DecreaseSendQuota reduces the send quota by 1.
func (i *Inflight) DecreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) > 0 {
atomic.AddInt32(&i.sendQuota, -1)
}
}
// IncreaseSendQuota increases the send quota by 1.
func (i *Inflight) IncreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
atomic.AddInt32(&i.sendQuota, 1)
}
}
// ResetSendQuota resets the send quota to the maximum allowed value.
func (i *Inflight) ResetSendQuota(n int32) {
atomic.StoreInt32(&i.sendQuota, n)
atomic.StoreInt32(&i.maximumSendQuota, n)
}

View File

@ -1,199 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sync/atomic"
"testing"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
func TestInflightSet(t *testing.T) {
cl, _, _ := newTestClient()
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.True(t, r)
require.NotNil(t, cl.State.Inflight.internal[1])
require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID)
r = cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.False(t, r)
}
func TestInflightGet(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
msg, ok := cl.State.Inflight.Get(2)
require.True(t, ok)
require.NotEqual(t, 0, msg.PacketID)
}
func TestInflightGetAllAndImmediate(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
require.Equal(t, []packets.Packet{
{PacketID: 1, Created: 1},
{PacketID: 2, Created: 2},
{PacketID: 3, Created: 3, Expiry: -1},
{PacketID: 4, Created: 4, Expiry: -1},
{PacketID: 5, Created: 5},
}, cl.State.Inflight.GetAll(false))
require.Equal(t, []packets.Packet{
{PacketID: 3, Created: 3, Expiry: -1},
{PacketID: 4, Created: 4, Expiry: -1},
}, cl.State.Inflight.GetAll(true))
}
func TestInflightLen(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightClone(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
cloned := cl.State.Inflight.Clone()
require.NotNil(t, cloned)
require.NotSame(t, cloned, cl.State.Inflight)
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
require.NotNil(t, cl.State.Inflight.internal[3])
r := cl.State.Inflight.Delete(3)
require.True(t, r)
require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID)
_, ok := cl.State.Inflight.Get(3)
require.False(t, ok)
r = cl.State.Inflight.Delete(3)
require.False(t, r)
}
func TestResetReceiveQuota(t *testing.T) {
i := NewInflights()
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
i.ResetReceiveQuota(6)
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota))
}
func TestReceiveQuota(t *testing.T) {
i := NewInflights()
i.receiveQuota = 4
i.maximumReceiveQuota = 5
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
// Return 1
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Try to go over max limit
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Reset to max 1
i.ResetReceiveQuota(1)
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
// Take 1
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
// Try to go below zero
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}
func TestResetSendQuota(t *testing.T) {
i := NewInflights()
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
i.ResetSendQuota(6)
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota))
}
func TestSendQuota(t *testing.T) {
i := NewInflights()
i.sendQuota = 4
i.maximumSendQuota = 5
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
// Return 1
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Try to go over max limit
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Reset to max 1
i.ResetSendQuota(1)
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
// Take 1
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
// Try to go below zero
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}
func TestNextImmediate(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
pk, ok := cl.State.Inflight.NextImmediate()
require.True(t, ok)
require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk)
r := cl.State.Inflight.Delete(3)
require.True(t, r)
pk, ok = cl.State.Inflight.NextImmediate()
require.True(t, ok)
require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk)
r = cl.State.Inflight.Delete(4)
require.True(t, r)
_, ok = cl.State.Inflight.NextImmediate()
require.False(t, ok)
}

View File

@ -1,118 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"encoding/json"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
)
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // the http server
log *zerolog.Logger // server logger
sysInfo *system.Info // pointers to the server data
end uint32 // ensure the close methods are only called once
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats {
if config == nil {
config = new(Config)
}
return &HTTPStats{
id: id,
address: address,
sysInfo: sysInfo,
config: config,
}
}
// ID returns the id of the listener.
func (l *HTTPStats) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *HTTPStats) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *HTTPStats) Protocol() string {
if l.listen != nil && l.listen.TLSConfig != nil {
return "https"
}
return "http"
}
// Init initializes the listener.
func (l *HTTPStats) Init(log *zerolog.Logger) error {
l.log = log
mux := http.NewServeMux()
mux.HandleFunc("/", l.jsonHandler)
l.listen = &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
Addr: l.address,
Handler: mux,
}
if l.config.TLSConfig != nil {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
// Serve starts listening for new connections and serving responses.
func (l *HTTPStats) Serve(establish EstablishFn) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *HTTPStats) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info := *l.sysInfo.Clone()
out, err := json.MarshalIndent(info, "", "\t")
if err != nil {
io.WriteString(w, err.Error())
}
w.Write(out)
}

View File

@ -1,127 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
func TestNewHTTPStats(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestHTTPStatsID(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "t1", l.ID())
}
func TestHTTPStatsAddress(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, testAddr, l.Address())
}
func TestHTTPStatsProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "http", l.Protocol())
}
func TestHTTPStatsTLSProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
}, nil)
l.Init(nil)
require.Equal(t, "https", l.Protocol())
}
func TestHTTPStatsInit(t *testing.T) {
sysInfo := new(system.Info)
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
require.NotNil(t, l.sysInfo)
require.Equal(t, sysInfo, l.sysInfo)
require.NotNil(t, l.listen)
require.Equal(t, testAddr, l.listen.Addr)
}
func TestHTTPStatsServeAndClose(t *testing.T) {
sysInfo := &system.Info{
Version: "test",
}
// setup http stats listener
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
// get body from stats address
resp, err := http.Get("http://localhost" + testAddr)
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// decode body from json and check data
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
// ensure listening is closed
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Get("http://localhost" + testAddr)
require.Error(t, err)
<-o
}
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
sysInfo := &system.Info{
Version: "test",
}
l := NewHTTPStats("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
}, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
l.Close(MockCloser)
}

View File

@ -1,135 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"net"
"sync"
"github.com/rs/zerolog"
)
// Config contains configuration values for a listener.
type Config struct {
// TLSConfig is a tls.Config configuration to be used with the listener.
// See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config
}
// EstablishFn is a callback function for establishing new clients.
type EstablishFn func(id string, c net.Conn) error
// CloseFunc is a callback function for closing all listener clients.
type CloseFn func(id string)
// Listener is an interface for network listeners. A network listener listens
// for incoming client connections and adds them to the server.
type Listener interface {
Init(*zerolog.Logger) error // open the network address
Serve(EstablishFn) // starting actively listening for new connections
ID() string // return the id of the listener
Address() string // the address of the listener
Protocol() string // the protocol in use by the listener
Close(CloseFn) // stop and close the listener
}
// Listeners contains the network listeners for the broker.
type Listeners struct {
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
internal map[string]Listener // a map of active listeners.
sync.RWMutex
}
// New returns a new instance of Listeners.
func New() *Listeners {
return &Listeners{
internal: map[string]Listener{},
}
}
// Add adds a new listener to the listeners map, keyed on id.
func (l *Listeners) Add(val Listener) {
l.Lock()
defer l.Unlock()
l.internal[val.ID()] = val
}
// Get returns the value of a listener if it exists.
func (l *Listeners) Get(id string) (Listener, bool) {
l.RLock()
defer l.RUnlock()
val, ok := l.internal[id]
return val, ok
}
// Len returns the length of the listeners map.
func (l *Listeners) Len() int {
l.RLock()
defer l.RUnlock()
return len(l.internal)
}
// Delete removes a listener from the internal map.
func (l *Listeners) Delete(id string) {
l.Lock()
defer l.Unlock()
delete(l.internal, id)
}
// Serve starts a listener serving from the internal map.
func (l *Listeners) Serve(id string, establisher EstablishFn) {
l.RLock()
defer l.RUnlock()
listener := l.internal[id]
go func(e EstablishFn) {
defer l.wg.Done()
l.wg.Add(1)
listener.Serve(e)
}(establisher)
}
// ServeAll starts all listeners serving from the internal map.
func (l *Listeners) ServeAll(establisher EstablishFn) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
for id := range l.internal {
ids[i] = id
i++
}
l.RUnlock()
for _, id := range ids {
l.Serve(id, establisher)
}
}
// Close stops a listener from the internal map.
func (l *Listeners) Close(id string, closer CloseFn) {
l.RLock()
defer l.RUnlock()
if listener, ok := l.internal[id]; ok {
listener.Close(closer)
}
}
// CloseAll iterates and closes all registered listeners.
func (l *Listeners) CloseAll(closer CloseFn) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
for id := range l.internal {
ids[i] = id
i++
}
l.RUnlock()
for _, id := range ids {
l.Close(id, closer)
}
l.wg.Wait()
}

View File

@ -1,177 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"log"
"os"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
const testAddr = ":22222"
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
tlsConfigBasic *tls.Config
)
func init() {
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfigBasic = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
}
func TestNew(t *testing.T) {
l := New()
require.NotNil(t, l.internal)
}
func TestAddListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
require.Contains(t, l.internal, "t1")
}
func TestGetListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
g, ok := l.Get("t1")
require.True(t, ok)
require.Equal(t, g.ID(), "t1")
}
func TestLenListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
require.Equal(t, 2, l.Len())
}
func TestDeleteListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
require.Contains(t, l.internal, "t1")
l.Delete("t1")
_, ok := l.Get("t1")
require.False(t, ok)
require.Nil(t, l.internal["t1"])
}
func TestServeListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
require.False(t, l.internal["t1"].(*MockListener).IsServing())
}
func TestServeAllListeners(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
l.Add(NewMockListener("t3", testAddr))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
require.True(t, l.internal["t2"].(*MockListener).IsServing())
require.True(t, l.internal["t3"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
l.Close("t2", MockCloser)
l.Close("t3", MockCloser)
require.False(t, l.internal["t1"].(*MockListener).IsServing())
require.False(t, l.internal["t2"].(*MockListener).IsServing())
require.False(t, l.internal["t3"].(*MockListener).IsServing())
}
func TestCloseListener(t *testing.T) {
l := New()
mocked := NewMockListener("t1", testAddr)
l.Add(mocked)
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
var closed bool
l.Close("t1", func(id string) {
closed = true
})
require.True(t, closed)
}
func TestCloseAllListeners(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
l.Add(NewMockListener("t3", testAddr))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
require.True(t, l.internal["t2"].(*MockListener).IsServing())
require.True(t, l.internal["t3"].(*MockListener).IsServing())
closed := make(map[string]bool)
l.CloseAll(func(id string) {
closed[id] = true
})
require.Contains(t, closed, "t1")
require.Contains(t, closed, "t2")
require.Contains(t, closed, "t3")
require.True(t, closed["t1"])
require.True(t, closed["t2"])
require.True(t, closed["t3"])
}

View File

@ -1,103 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"fmt"
"net"
"sync"
"github.com/rs/zerolog"
)
// MockEstablisher is a function signature which can be used in testing.
func MockEstablisher(id string, c net.Conn) error {
return nil
}
// MockCloser is a function signature which can be used in testing.
func MockCloser(id string) {}
// MockListener is a mock listener for establishing client connections.
type MockListener struct {
sync.RWMutex
id string // the id of the listener
address string // the network address the listener binds to
Config *Config // configuration for the listener
done chan bool // indicate the listener is done
Serving bool // indicate the listener is serving
Listening bool // indiciate the listener is listening
ErrListen bool // throw an error on listen
}
// NewMockListener returns a new instance of MockListener.
func NewMockListener(id, address string) *MockListener {
return &MockListener{
id: id,
address: address,
done: make(chan bool),
}
}
// Serve serves the mock listener.
func (l *MockListener) Serve(establisher EstablishFn) {
l.Lock()
l.Serving = true
l.Unlock()
for range l.done {
return
}
}
// Init initializes the listener.
func (l *MockListener) Init(log *zerolog.Logger) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
}
l.Lock()
defer l.Unlock()
l.Listening = true
return nil
}
// ID returns the id of the mock listener.
func (l *MockListener) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *MockListener) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *MockListener) Protocol() string {
return "mock"
}
// Close closes the mock listener.
func (l *MockListener) Close(closer CloseFn) {
l.Lock()
defer l.Unlock()
l.Serving = false
closer(l.id)
close(l.done)
}
// IsServing indicates whether the mock listener is serving.
func (l *MockListener) IsServing() bool {
l.Lock()
defer l.Unlock()
return l.Serving
}
// IsListening indicates whether the mock listener is listening.
func (l *MockListener) IsListening() bool {
l.Lock()
defer l.Unlock()
return l.Listening
}

View File

@ -1,99 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestMockEstablisher(t *testing.T) {
_, w := net.Pipe()
err := MockEstablisher("t1", w)
require.NoError(t, err)
w.Close()
}
func TestNewMockListener(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.id)
require.Equal(t, testAddr, mocked.address)
}
func TestMockListenerID(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.ID())
}
func TestMockListenerAddress(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, testAddr, mocked.Address())
}
func TestMockListenerProtocol(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "mock", mocked.Protocol())
}
func TestNewMockListenerIsListening(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsListening())
}
func TestNewMockListenerIsServing(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsServing())
}
func TestNewMockListenerInit(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.id)
require.Equal(t, testAddr, mocked.address)
require.Equal(t, false, mocked.IsListening())
err := mocked.Init(nil)
require.NoError(t, err)
require.Equal(t, true, mocked.IsListening())
}
func TestNewMockListenerInitFailure(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
mocked.ErrListen = true
err := mocked.Init(nil)
require.Error(t, err)
}
func TestMockListenerServe(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsServing())
o := make(chan bool)
go func(o chan bool) {
mocked.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
require.Equal(t, true, mocked.IsServing())
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
mocked.Init(nil)
}
func TestMockListenerClose(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}

View File

@ -1,88 +0,0 @@
package listeners
import (
"net"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// Net is a listener for establishing client connections on basic TCP protocol.
type Net struct { // [MQTT-4.2.0-1]
mu sync.Mutex
listener net.Listener // a net.Listener which will listen for new clients
id string // the internal id of the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
func NewNet(id string, listener net.Listener) *Net {
return &Net{
id: id,
listener: listener,
}
}
// ID returns the id of the listener.
func (l *Net) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Net) Address() string {
return l.listener.Addr().String()
}
// Protocol returns the network of the listener.
func (l *Net) Protocol() string {
return l.listener.Addr().Network()
}
// Init initializes the listener.
func (l *Net) Init(log *zerolog.Logger) error {
l.log = log
return nil
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *Net) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listener.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *Net) Close(closeClients CloseFn) {
l.mu.Lock()
defer l.mu.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listener != nil {
err := l.listener.Close()
if err != nil {
return
}
}
}

View File

@ -1,105 +0,0 @@
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewNet(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.id)
}
func TestNetID(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.ID())
}
func TestNetAddress(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, n.Addr().String(), l.Address())
}
func TestNetProtocol(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "tcp", l.Protocol())
}
func TestNetInit(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
}
func TestNetServeAndClose(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestNetEstablishThenEnd(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", n.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@ -1,108 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"net"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// TCP is a listener for establishing client connections on basic TCP protocol.
type TCP struct { // [MQTT-4.2.0-1]
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
listen net.Listener // a net.Listener which will listen for new clients
config *Config // configuration values for the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
func NewTCP(id, address string, config *Config) *TCP {
if config == nil {
config = new(Config)
}
return &TCP{
id: id,
address: address,
config: config,
}
}
// ID returns the id of the listener.
func (l *TCP) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *TCP) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *TCP) Protocol() string {
return "tcp"
}
// Init initializes the listener.
func (l *TCP) Init(log *zerolog.Logger) error {
l.log = log
var err error
if l.config.TLSConfig != nil {
l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig)
} else {
l.listen, err = net.Listen("tcp", l.address)
}
return err
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *TCP) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@ -1,131 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewTCP(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestTCPID(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "t1", l.ID())
}
func TestTCPAddress(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, testAddr, l.Address())
}
func TestTCPProtocol(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "tcp", l.Protocol())
}
func TestTCPProtocolTLS(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
l.Init(&logger)
defer l.listen.Close()
require.Equal(t, "tcp", l.Protocol())
}
func TestTCPInit(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewTCP("t2", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
require.NotNil(t, l2.config.TLSConfig)
}
func TestTCPServeAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestTCPServeTLSAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestTCPEstablishThenEnd(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@ -1,98 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"net"
"os"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once.
}
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock {
return &UnixSock{
id: id,
address: address,
}
}
// ID returns the id of the listener.
func (l *UnixSock) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *UnixSock) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *UnixSock) Protocol() string {
return "unix"
}
// Init initializes the listener.
func (l *UnixSock) Init(log *zerolog.Logger) error {
l.log = log
var err error
_ = os.Remove(l.address)
l.listen, err = net.Listen("unix", l.address)
return err
}
// Serve starts waiting for new UnixSock connections, and calls the establish
// connection callback for any received.
func (l *UnixSock) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *UnixSock) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@ -1,96 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testUnixAddr = "mochi.sock"
func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address)
}
func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.ID())
}
func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, testUnixAddr, l.Address())
}
func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "unix", l.Protocol())
}
func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr)
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
}
func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("unix", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@ -1,178 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"errors"
"io"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
var (
// ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary")
)
// Websocket is a listener for establishing websocket connections.
type Websocket struct { // [MQTT-4.2.0-1]
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // an http server for serving websocket connections
log *zerolog.Logger // server logger
establish EstablishFn // the server's establish connection handler
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
end uint32 // ensure the close methods are only called once
}
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string, config *Config) *Websocket {
if config == nil {
config = new(Config)
}
return &Websocket{
id: id,
address: address,
config: config,
upgrader: &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool {
return true
},
},
}
}
// ID returns the id of the listener.
func (l *Websocket) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Websocket) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *Websocket) Protocol() string {
if l.config.TLSConfig != nil {
return "wss"
}
return "ws"
}
// Init initializes the listener.
func (l *Websocket) Init(log *zerolog.Logger) error {
l.log = log
mux := http.NewServeMux()
mux.HandleFunc("/", l.handler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
TLSConfig: l.config.TLSConfig,
ReadTimeout: 60 * time.Second,
WriteTimeout: 60 * time.Second,
}
return nil
}
// handler upgrades and handles an incoming websocket connection.
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
c, err := l.upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
err = l.establish(l.id, &wsConn{c.UnderlyingConn(), c})
if err != nil {
l.log.Warn().Err(err).Send()
}
}
// Serve starts waiting for new Websocket connections, and calls the connection
// establishment callback for any received.
func (l *Websocket) Serve(establish EstablishFn) {
l.establish = establish
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *Websocket) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
type wsConn struct {
net.Conn
c *websocket.Conn
}
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
func (ws *wsConn) Read(p []byte) (int, error) {
op, r, err := ws.c.NextReader()
if err != nil {
return 0, err
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return 0, err
}
var n, br int
for {
br, err = r.Read(p[n:])
n += br
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
}
return n, err
}
}
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (int, error) {
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return 0, err
}
return len(p), nil
}
// Close signals the underlying websocket conn to close.
func (ws *wsConn) Close() error {
return ws.Conn.Close()
}

View File

@ -1,114 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
func TestNewWebsocket(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestWebsocketID(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "t1", l.ID())
}
func TestWebsocketAddress(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, testAddr, l.Address())
}
func TestWebsocketProtocol(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "ws", l.Protocol())
}
func TestWebsocketProtocoTLS(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
require.Equal(t, "wss", l.Protocol())
}
func TestWebsockeInit(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Nil(t, l.listen)
err := l.Init(nil)
require.NoError(t, err)
require.NotNil(t, l.listen)
}
func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
}
func TestWebsocketServeTLSAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
e := make(chan bool)
l.establish = func(id string, c net.Conn) error {
e <- true
return nil
}
s := httptest.NewServer(http.HandlerFunc(l.handler))
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
require.NoError(t, err)
require.Equal(t, true, <-e)
s.Close()
ws.Close()
}

View File

@ -1,172 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"encoding/binary"
"io"
"unicode/utf8"
"unsafe"
)
// bytesToString provides a zero-alloc no-copy byte to string conversion.
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
func bytesToString(bs []byte) string {
return *(*string)(unsafe.Pointer(&bs))
}
// decodeUint16 extracts the value of two bytes from a byte array.
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
if len(buf) < offset+2 {
return 0, 0, ErrMalformedOffsetUintOutOfRange
}
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
}
// decodeUint32 extracts the value of four bytes from a byte array.
func decodeUint32(buf []byte, offset int) (uint32, int, error) {
if len(buf) < offset+4 {
return 0, 0, ErrMalformedOffsetUintOutOfRange
}
return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil
}
// decodeString extracts a string from a byte array, beginning at an offset.
func decodeString(buf []byte, offset int) (string, int, error) {
b, n, err := decodeBytes(buf, offset)
if err != nil {
return "", 0, err
}
if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5]
return "", 0, ErrMalformedInvalidUTF8
}
return bytesToString(b), n, nil
}
// validUTF8 checks if the byte array contains valid UTF-8 characters.
func validUTF8(b []byte) bool {
return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2]
}
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
length, next, err := decodeUint16(buf, offset)
if err != nil {
return make([]byte, 0), 0, err
}
if next+int(length) > len(buf) {
return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange
}
return buf[next : next+int(length)], next + int(length), nil
}
// decodeByte extracts the value of a byte from a byte array.
func decodeByte(buf []byte, offset int) (byte, int, error) {
if len(buf) <= offset {
return 0, 0, ErrMalformedOffsetByteOutOfRange
}
return buf[offset], offset + 1, nil
}
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
if len(buf) <= offset {
return false, 0, ErrMalformedOffsetBoolOutOfRange
}
return 1&buf[offset] > 0, offset + 1, nil
}
// encodeBool returns a byte instead of a bool.
func encodeBool(b bool) byte {
if b {
return 1
}
return 0
}
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
func encodeBytes(val []byte) []byte {
// In most circumstances the number of bytes being encoded is small.
// Setting the cap to a low amount allows us to account for those without
// triggering allocation growth on append unless we need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, val...)
}
// encodeUint16 encodes a uint16 value to a byte array.
func encodeUint16(val uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, val)
return buf
}
// encodeUint32 encodes a uint16 value to a byte array.
func encodeUint32(val uint32) []byte {
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, val)
return buf
}
// encodeString encodes a string to a byte array.
func encodeString(val string) []byte {
// Like encodeBytes, we set the cap to a small number to avoid
// triggering allocation growth on append unless we absolutely need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, []byte(val)...)
}
// encodeLength writes length bits for the header.
func encodeLength(b *bytes.Buffer, length int64) {
// 1.5.5 Variable Byte Integer encode non-normative
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
for {
eb := byte(length % 128)
length /= 128
if length > 0 {
eb |= 0x80
}
b.WriteByte(eb)
if length == 0 {
break // [MQTT-1.5.5-1]
}
}
}
func DecodeLength(b io.ByteReader) (n, bu int, err error) {
// see 1.5.5 Variable Byte Integer decode non-normative
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
var multiplier uint32
var value uint32
bu = 1
for {
eb, err := b.ReadByte()
if err != nil {
return 0, bu, err
}
value |= uint32(eb&127) << multiplier
if value > 268435455 {
return 0, bu, ErrMalformedVariableByteInteger
}
if (eb & 128) == 0 {
break
}
multiplier += 7
bu++
}
return int(value), bu, nil
}

View File

@ -1,422 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"errors"
"fmt"
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBytesToString(t *testing.T) {
b := []byte{'a', 'b', 'c'}
require.Equal(t, "abc", bytesToString(b))
}
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
rawBytes []byte
result string
offset int
shouldFail error
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: "a/b/c/d",
},
{
offset: 14,
rawBytes: []byte{
Connect << 4, 17, // Fixed header
0, 6, // Protocol Name - MSB+LSB
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
3, // Protocol Version
0, // Packet Flags
0, 30, // Keepalive
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "zen"},
},
result: "hey",
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: "1/2/3/4/a/b/c/d/e/^/@/!",
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: "x/y/z",
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
{
offset: 17,
rawBytes: []byte{
Connect << 4, 0, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 6, // Will Topic - MSB+LSB
'l',
},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrMalformedInvalidUTF8,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3]
result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0)
require.NoError(t, err)
require.Equal(t, "\ufeff", result)
}
func TestDecodeBytes(t *testing.T) {
expect := []struct {
rawBytes []byte
result []uint8
next int
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session)
result: []byte{0x4d, 0x51, 0x54, 0x54},
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start
result: []byte{0x4d, 0x51, 0x54, 0x54},
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 0,
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func TestDecodeByte(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint8
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
result: uint8(0x00),
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x04),
offset: 1,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x4d),
offset: 2,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x51),
offset: 3,
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
offset: 8,
shouldFail: ErrMalformedOffsetByteOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
})
}
}
func TestDecodeUint16(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint16
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x07),
offset: 0,
},
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x761),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
})
}
}
func TestDecodeUint32(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint32
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 0, 0, 7, 8},
result: uint32(7),
offset: 0,
},
{
rawBytes: []byte{0, 0, 1, 226, 64, 8},
result: uint32(123456),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+4, offset)
})
}
}
func TestDecodeByteBool(t *testing.T) {
expect := []struct {
rawBytes []byte
result bool
offset int
shouldFail error
}{
{
rawBytes: []byte{0x00, 0x00},
result: false,
},
{
rawBytes: []byte{0x01, 0x00},
result: true,
},
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: ErrMalformedOffsetBoolOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, 1, offset)
})
}
}
func TestDecodeLength(t *testing.T) {
b := bytes.NewBuffer([]byte{0x78})
n, bu, err := DecodeLength(b)
require.NoError(t, err)
require.Equal(t, 120, n)
require.Equal(t, 1, bu)
b = bytes.NewBuffer([]byte{255, 255, 255, 127})
n, bu, err = DecodeLength(b)
require.NoError(t, err)
require.Equal(t, 268435455, n)
require.Equal(t, 4, bu)
}
func TestDecodeLengthErrors(t *testing.T) {
b := bytes.NewBuffer([]byte{})
_, _, err := DecodeLength(b)
require.Error(t, err)
b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f})
_, _, err = DecodeLength(b)
require.Error(t, err)
require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
}
func TestEncodeBool(t *testing.T) {
result := encodeBool(true)
require.Equal(t, byte(1), result)
result = encodeBool(false)
require.Equal(t, byte(0), result)
// Check failure.
result = encodeBool(false)
require.NotEqual(t, byte(1), result)
}
func TestEncodeBytes(t *testing.T) {
result := encodeBytes([]byte("testing"))
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result)
result = encodeBytes([]byte("testing"))
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result)
}
func TestEncodeUint16(t *testing.T) {
result := encodeUint16(0)
require.Equal(t, []byte{0x00, 0x00}, result)
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result)
result = encodeUint16(math.MaxUint16)
require.Equal(t, []byte{0xff, 0xff}, result)
}
func TestEncodeUint32(t *testing.T) {
result := encodeUint32(7)
require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result)
result = encodeUint32(32767)
require.Equal(t, []byte{0, 0, 127, 255}, result)
result = encodeUint32(math.MaxUint32)
require.Equal(t, []byte{255, 255, 255, 255}, result)
}
func TestEncodeString(t *testing.T) {
result := encodeString("testing")
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result)
result = encodeString("")
require.Equal(t, []uint8{0x00, 0x00}, result)
result = encodeString("a")
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result)
result = encodeString("b")
require.NotEqual(t, []uint8{0x00, 0x00}, result)
}
func TestEncodeLength(t *testing.T) {
b := new(bytes.Buffer)
encodeLength(b, 120)
require.Equal(t, []byte{0x78}, b.Bytes())
b = new(bytes.Buffer)
encodeLength(b, math.MaxInt64)
require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes())
}
func TestValidUTF8(t *testing.T) {
require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}))
require.False(t, validUTF8([]byte{0xff, 0xff}))
require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74}))
}

View File

@ -1,147 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
// Code contains a reason code and reason string for a response.
type Code struct {
Reason string
Code byte
}
// String returns the readable reason for a code.
func (c Code) String() string {
return c.Reason
}
// Error returns the readable reason for a code.
func (c Code) Error() string {
return c.Reason
}
var (
// QosCodes indicicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
2: CodeGrantedQos2,
}
CodeSuccess = Code{Code: 0x00, Reason: "success"}
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"}
CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"}
CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"}
CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"}
CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"}
CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"}
CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"}
ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"}
ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"}
ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"}
ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"}
ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"}
ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"}
ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"}
ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"}
ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"}
ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"}
ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"}
ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"}
ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"}
ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"}
ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"}
ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"}
ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"}
ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"}
ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"}
ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"}
ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"}
ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"}
ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"}
ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"}
ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"}
ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"}
ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"}
ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"}
ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"}
ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"}
ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"}
ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"}
ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"}
ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"}
ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"}
ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"}
ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"}
ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"}
ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"}
ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"}
ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"}
ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"}
ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"}
ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"}
ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"}
ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"}
ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"}
ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"}
ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"}
ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"}
ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"}
ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"}
ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"}
ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"}
ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"}
ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"}
ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"}
ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"}
ErrServerBusy = Code{Code: 0x89, Reason: "server busy"}
ErrBanned = Code{Code: 0x8A, Reason: "banned"}
ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"}
ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"}
ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"}
ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"}
ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"}
ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"}
ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"}
ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"}
ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"}
ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"}
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
// MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
Err3ClientIdentifierNotValid = Code{Code: 0x02}
Err3ServerUnavailable = Code{Code: 0x03}
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
Err3NotAuthorized = Code{Code: 0x05}
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
// This is required because MQTTv3 has different return byte specification.
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
V5CodesToV3 = map[Code]Code{
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
ErrServerUnavailable: Err3ServerUnavailable,
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
ErrBadUsernameOrPassword: Err3NotAuthorized,
}
)

View File

@ -1,29 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestCodesString(t *testing.T) {
c := Code{
Reason: "test",
Code: 0x1,
}
require.Equal(t, "test", c.String())
}
func TestCodesErrorr(t *testing.T) {
c := Code{
Reason: "error",
Code: 0x1,
}
require.Equal(t, "error", error(c).Error())
}

View File

@ -1,63 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
)
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Remaining int `json:"remaining"` // the number of remaining bytes in the payload.
Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte `json:"qos"` // indicates the quality of service expected.
Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time.
Retain bool `json:"retain"` // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, int64(fh.Remaining))
}
// Decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(hb byte) error {
fh.Type = hb >> 4 // Get the message type from the first 4 bytes.
switch fh.Type {
case Publish:
if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 {
return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4]
}
fh.Dup = (hb>>3)&0x01 > 0 // is duplicate
fh.Qos = (hb >> 1) & 0x03 // qos flag
fh.Retain = hb&0x01 > 0 // is retain flag
case Pubrel:
fallthrough
case Subscribe:
fallthrough
case Unsubscribe:
if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1]
return ErrMalformedFlags
}
fh.Qos = (hb >> 1) & 0x03
default:
if (hb>>0)&0x01 != 0 ||
(hb>>1)&0x01 != 0 ||
(hb>>2)&0x01 != 0 ||
(hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1]
return ErrMalformedFlags
}
}
if fh.Qos == 0 && fh.Dup {
return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2]
}
return nil
}

View File

@ -1,237 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
type fixedHeaderTable struct {
desc string
rawBytes []byte
header FixedHeader
packetError bool
expect error
}
var fixedHeaderExpected = []fixedHeaderTable{
{
desc: "connect",
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "connack",
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "publish",
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "publish qos 1",
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "publish qos 1 retain",
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
desc: "publish qos 2",
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
desc: "publish qos 2 retain",
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
desc: "publish dup qos 0",
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
expect: ErrProtocolViolationDupNoQos,
},
{
desc: "publish dup qos 0 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
expect: ErrProtocolViolationDupNoQos,
},
{
desc: "publish dup qos 1 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
desc: "publish dup qos 2 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
desc: "puback",
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pubrec",
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pubrel",
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "pubcomp",
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "subscribe",
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "suback",
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "unsubscribe",
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "unsuback",
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pingreq",
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pingresp",
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "disconnect",
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "auth",
rawBytes: []byte{Auth << 4, 0x00},
header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
desc: "remaining length 10",
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
desc: "remaining length 512",
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
desc: "remaining length 978",
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
desc: "remaining length 20202",
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
desc: "remaining length oversize",
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
desc: "invalid type dup is true",
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid type qos is 1",
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid type retain is true",
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid publish qos bits 1 + 2 set",
rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00},
header: FixedHeader{Type: Publish},
expect: ErrProtocolViolationQosOutOfRange,
},
{
desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0",
rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00},
header: FixedHeader{Type: Pubrel, Qos: 1},
expect: ErrMalformedFlags,
},
{
desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0",
rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00},
header: FixedHeader{Type: Subscribe, Qos: 1},
expect: ErrMalformedFlags,
},
}
func TestFixedHeaderEncode(t *testing.T) {
for _, wanted := range fixedHeaderExpected {
t.Run(wanted.desc, func(t *testing.T) {
buf := new(bytes.Buffer)
wanted.header.Encode(buf)
if wanted.expect == nil {
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()))
require.EqualValues(t, wanted.rawBytes, buf.Bytes())
}
})
}
}
func TestFixedHeaderDecode(t *testing.T) {
for _, wanted := range fixedHeaderExpected {
t.Run(wanted.desc, func(t *testing.T) {
fh := new(FixedHeader)
err := fh.Decode(wanted.rawBytes[0])
if wanted.expect != nil {
require.Equal(t, wanted.expect, err)
} else {
require.NoError(t, err)
require.Equal(t, wanted.header.Type, fh.Type)
require.Equal(t, wanted.header.Dup, fh.Dup)
require.Equal(t, wanted.header.Qos, fh.Qos)
require.Equal(t, wanted.header.Retain, fh.Retain)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,505 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"testing"
"github.com/jinzhu/copier"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var packetList = []byte{
Connect,
Connack,
Publish,
Puback,
Pubrec,
Pubrel,
Pubcomp,
Subscribe,
Suback,
Unsubscribe,
Unsuback,
Pingreq,
Pingresp,
Disconnect,
Auth,
}
var pkTable = []TPacketCase{
TPacketData[Connect].Get(TConnectMqtt311),
TPacketData[Connect].Get(TConnectMqtt5),
TPacketData[Connect].Get(TConnectUserPassLWT),
TPacketData[Connack].Get(TConnackAcceptedMqtt5),
TPacketData[Connack].Get(TConnackAcceptedNoSession),
TPacketData[Publish].Get(TPublishBasic),
TPacketData[Publish].Get(TPublishMqtt5),
TPacketData[Puback].Get(TPuback),
TPacketData[Pubrec].Get(TPubrec),
TPacketData[Pubrel].Get(TPubrel),
TPacketData[Pubcomp].Get(TPubcomp),
TPacketData[Subscribe].Get(TSubscribe),
TPacketData[Subscribe].Get(TSubscribeMqtt5),
TPacketData[Suback].Get(TSuback),
TPacketData[Unsubscribe].Get(TUnsubscribe),
TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
TPacketData[Pingreq].Get(TPingreq),
TPacketData[Pingresp].Get(TPingresp),
TPacketData[Disconnect].Get(TDisconnect),
TPacketData[Disconnect].Get(TDisconnectMqtt5),
}
func TestNewPackets(t *testing.T) {
s := NewPackets()
require.NotNil(t, s.internal)
}
func TestPacketsAdd(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{})
require.Contains(t, s.internal, "cl1")
}
func TestPacketsGet(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
pk, ok := s.Get("cl1")
require.True(t, ok)
require.Equal(t, "a1", pk.TopicName)
}
func TestPacketsGetAll(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
s.Add("cl3", Packet{TopicName: "a3"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Contains(t, s.internal, "cl3")
subs := s.GetAll()
require.Len(t, subs, 3)
}
func TestPacketsLen(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Equal(t, 2, s.Len())
}
func TestSPacketsDelete(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
require.Contains(t, s.internal, "cl1")
s.Delete("cl1")
_, ok := s.Get("cl1")
require.False(t, ok)
}
func TestFormatPacketID(t *testing.T) {
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
packet := &Packet{PacketID: id}
require.Equal(t, fmt.Sprint(id), packet.FormatID())
}
}
func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
p := &Subscription{
Qos: 2,
NoLocal: true,
RetainAsPublished: true,
RetainHandling: 2,
}
x := new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
p = &Subscription{
Qos: 1,
NoLocal: false,
RetainAsPublished: false,
RetainHandling: 1,
}
x = new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
}
func TestPacketEncode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !encodeTestOK(wanted) {
return
}
pk := new(Packet)
copier.Copy(pk, wanted.Packet)
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
pk.Mods.AllowResponseInfo = true
buf := new(bytes.Buffer)
var err error
switch pkt {
case Connect:
err = pk.ConnectEncode(buf)
case Connack:
err = pk.ConnackEncode(buf)
case Publish:
err = pk.PublishEncode(buf)
case Puback:
err = pk.PubackEncode(buf)
case Pubrec:
err = pk.PubrecEncode(buf)
case Pubrel:
err = pk.PubrelEncode(buf)
case Pubcomp:
err = pk.PubcompEncode(buf)
case Subscribe:
err = pk.SubscribeEncode(buf)
case Suback:
err = pk.SubackEncode(buf)
case Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case Unsuback:
err = pk.UnsubackEncode(buf)
case Pingreq:
err = pk.PingreqEncode(buf)
case Pingresp:
err = pk.PingrespEncode(buf)
case Disconnect:
err = pk.DisconnectEncode(buf)
case Auth:
err = pk.AuthEncode(buf)
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
encoded := buf.Bytes()
// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
if len(wanted.ActualBytes) > 0 {
wanted.RawBytes = wanted.ActualBytes
}
require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
})
}
}
}
func TestPacketDecode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !decodeTestOK(wanted) {
return
}
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
pk.Mods.AllowResponseInfo = true
pk.FixedHeader.Decode(wanted.RawBytes[0])
if len(wanted.RawBytes) > 0 {
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
}
if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
pk.ProtocolVersion = wanted.Packet.ProtocolVersion
}
buf := wanted.RawBytes[2:]
var err error
switch pkt {
case Connect:
err = pk.ConnectDecode(buf)
case Connack:
err = pk.ConnackDecode(buf)
case Publish:
err = pk.PublishDecode(buf)
case Puback:
err = pk.PubackDecode(buf)
case Pubrec:
err = pk.PubrecDecode(buf)
case Pubrel:
err = pk.PubrelDecode(buf)
case Pubcomp:
err = pk.PubcompDecode(buf)
case Subscribe:
err = pk.SubscribeDecode(buf)
case Suback:
err = pk.SubackDecode(buf)
case Unsubscribe:
err = pk.UnsubscribeDecode(buf)
case Unsuback:
err = pk.UnsubackDecode(buf)
case Pingreq:
err = pk.PingreqDecode(buf)
case Pingresp:
err = pk.PingrespDecode(buf)
case Disconnect:
err = pk.DisconnectDecode(buf)
case Auth:
err = pk.AuthDecode(buf)
}
if wanted.FailFirst != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
if pkt == Connect {
// we use ProtocolVersion for controlling packet encoding, but we don't need to test
// against it unless it's a connect packet.
require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
}
require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
})
}
}
}
func TestValidate(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if wanted.Group == "validate" || wanted.Primary {
pk := wanted.Packet
var err error
switch pkt {
case Connect:
err = pk.ConnectValidate()
case Publish:
err = pk.PublishValidate(1024)
case Subscribe:
err = pk.SubscribeValidate()
case Unsubscribe:
err = pk.UnsubscribeValidate()
case Auth:
err = pk.AuthValidate()
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
}
}
})
}
}
}
func TestAckValidatePubrec(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoMatchingSubscribers.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicNameInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrPayloadFormatInvalid.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubrel(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubcomp(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateSuback(t *testing.T) {
for _, b := range []byte{
CodeGrantedQos0.Code,
CodeGrantedQos1.Code,
CodeGrantedQos2.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrSharedSubscriptionsNotSupported.Code,
ErrSubscriptionIdentifiersNotSupported.Code,
ErrWildcardSubscriptionsNotSupported.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateUnsuback(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoSubscriptionExisted.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestReasonCodeValidMisc(t *testing.T) {
pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
require.True(t, pk.ReasonCodeValid())
}
func TestCopy(t *testing.T) {
for _, tt := range pkTable {
pkc := tt.Packet.Copy(true)
require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
pkcc := tt.Packet.Copy(false)
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
}
}
func TestMergeSubscription(t *testing.T) {
sub := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 0,
RetainAsPublished: false,
NoLocal: false,
Identifier: 1,
}
sub2 := Subscription{
Filter: "a/b/d",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 2,
}
expect := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 1,
Identifiers: map[string]int{
"a/b/c": 1,
"a/b/d": 2,
},
}
require.Equal(t, expect, sub.Merge(sub2))
}

View File

@ -1,479 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"strings"
)
const (
PropPayloadFormat byte = 1
PropMessageExpiryInterval byte = 2
PropContentType byte = 3
PropResponseTopic byte = 8
PropCorrelationData byte = 9
PropSubscriptionIdentifier byte = 11
PropSessionExpiryInterval byte = 17
PropAssignedClientID byte = 18
PropServerKeepAlive byte = 19
PropAuthenticationMethod byte = 21
PropAuthenticationData byte = 22
PropRequestProblemInfo byte = 23
PropWillDelayInterval byte = 24
PropRequestResponseInfo byte = 25
PropResponseInfo byte = 26
PropServerReference byte = 28
PropReasonString byte = 31
PropReceiveMaximum byte = 33
PropTopicAliasMaximum byte = 34
PropTopicAlias byte = 35
PropMaximumQos byte = 36
PropRetainAvailable byte = 37
PropUser byte = 38
PropMaximumPacketSize byte = 39
PropWildcardSubAvailable byte = 40
PropSubIDAvailable byte = 41
PropSharedSubAvailable byte = 42
)
// validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1},
PropMessageExpiryInterval: {Publish: 1},
PropContentType: {Publish: 1},
PropResponseTopic: {Publish: 1},
PropCorrelationData: {Publish: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1},
PropServerKeepAlive: {Connack: 1},
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {Connect: 1},
PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1},
PropReasonString: {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropReceiveMaximum: {Connect: 1, Connack: 1},
PropTopicAliasMaximum: {Connect: 1, Connack: 1},
PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1},
PropSharedSubAvailable: {Connack: 1},
}
// UserProperty is an arbitrary key-value pair for a packet user properties array.
type UserProperty struct { // [MQTT-1.5.7-1]
Key string `json:"k"`
Val string `json:"v"`
}
// Properties contains all of the mqtt v5 properties available for a packet.
// Some properties have valid values of 0 or not-present. In this case, we opt for
// property flags to indicate the usage of property.
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
type Properties struct {
CorrelationData []byte `json:"cd"`
SubscriptionIdentifier []int `json:"si"`
AuthenticationData []byte `json:"ad"`
User []UserProperty `json:"user"`
ContentType string `json:"ct"`
ResponseTopic string `json:"rt"`
AssignedClientID string `json:"aci"`
AuthenticationMethod string `json:"am"`
ResponseInfo string `json:"ri"`
ServerReference string `json:"sr"`
ReasonString string `json:"rs"`
MessageExpiryInterval uint32 `json:"me"`
SessionExpiryInterval uint32 `json:"sei"`
WillDelayInterval uint32 `json:"wdi"`
MaximumPacketSize uint32 `json:"mps"`
ServerKeepAlive uint16 `json:"ska"`
ReceiveMaximum uint16 `json:"rm"`
TopicAliasMaximum uint16 `json:"tam"`
TopicAlias uint16 `json:"ta"`
PayloadFormat byte `json:"pf"`
PayloadFormatFlag bool `json:"fpf"`
SessionExpiryIntervalFlag bool `json:"fsei"`
ServerKeepAliveFlag bool `json:"fska"`
RequestProblemInfo byte `json:"rpi"`
RequestProblemInfoFlag bool `json:"frpi"`
RequestResponseInfo byte `json:"rri"`
TopicAliasFlag bool `json:"fta"`
MaximumQos byte `json:"mqos"`
MaximumQosFlag bool `json:"fmqos"`
RetainAvailable byte `json:"ra"`
RetainAvailableFlag bool `json:"fra"`
WildcardSubAvailable byte `json:"wsa"`
WildcardSubAvailableFlag bool `json:"fwsa"`
SubIDAvailable byte `json:"sida"`
SubIDAvailableFlag bool `json:"fsida"`
SharedSubAvailable byte `json:"ssa"`
SharedSubAvailableFlag bool `json:"fssa"`
}
// Copy creates a new Properties struct with copies of the values.
func (p *Properties) Copy(allowTransfer bool) Properties {
pr := Properties{
PayloadFormat: p.PayloadFormat, // [MQTT-3.3.2-4]
PayloadFormatFlag: p.PayloadFormatFlag,
MessageExpiryInterval: p.MessageExpiryInterval,
ContentType: p.ContentType, // [MQTT-3.3.2-20]
ResponseTopic: p.ResponseTopic, // [MQTT-3.3.2-15]
SessionExpiryInterval: p.SessionExpiryInterval,
SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag,
AssignedClientID: p.AssignedClientID,
ServerKeepAlive: p.ServerKeepAlive,
ServerKeepAliveFlag: p.ServerKeepAliveFlag,
AuthenticationMethod: p.AuthenticationMethod,
RequestProblemInfo: p.RequestProblemInfo,
RequestProblemInfoFlag: p.RequestProblemInfoFlag,
WillDelayInterval: p.WillDelayInterval,
RequestResponseInfo: p.RequestResponseInfo,
ResponseInfo: p.ResponseInfo,
ServerReference: p.ServerReference,
ReasonString: p.ReasonString,
ReceiveMaximum: p.ReceiveMaximum,
TopicAliasMaximum: p.TopicAliasMaximum,
TopicAlias: 0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27]
MaximumQos: p.MaximumQos,
MaximumQosFlag: p.MaximumQosFlag,
RetainAvailable: p.RetainAvailable,
RetainAvailableFlag: p.RetainAvailableFlag,
MaximumPacketSize: p.MaximumPacketSize,
WildcardSubAvailable: p.WildcardSubAvailable,
WildcardSubAvailableFlag: p.WildcardSubAvailableFlag,
SubIDAvailable: p.SubIDAvailable,
SubIDAvailableFlag: p.SubIDAvailableFlag,
SharedSubAvailable: p.SharedSubAvailable,
SharedSubAvailableFlag: p.SharedSubAvailableFlag,
}
if allowTransfer {
pr.TopicAlias = p.TopicAlias
pr.TopicAliasFlag = p.TopicAliasFlag
}
if len(p.CorrelationData) > 0 {
pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16]
}
if len(p.SubscriptionIdentifier) > 0 {
pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...)
}
if len(p.AuthenticationData) > 0 {
pr.AuthenticationData = append([]byte{}, p.AuthenticationData...)
}
if len(p.User) > 0 {
pr.User = []UserProperty{}
for _, v := range p.User {
pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17]
Key: v.Key,
Val: v.Val,
})
}
}
return pr
}
// canEncode returns true if the property type is valid for the packet type.
func (p *Properties) canEncode(pkt byte, k byte) bool {
return validPacketProperties[k][pkt] == 1
}
// Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
if p == nil {
return
}
var buf bytes.Buffer
pkt := pk.FixedHeader.Type
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat)
}
if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 {
buf.WriteByte(PropMessageExpiryInterval)
buf.Write(encodeUint32(p.MessageExpiryInterval))
}
if p.canEncode(pkt, PropContentType) && p.ContentType != "" {
buf.WriteByte(PropContentType)
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData))
}
if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 {
for _, v := range p.SubscriptionIdentifier {
if v > 0 {
buf.WriteByte(PropSubscriptionIdentifier)
encodeLength(&buf, int64(v))
}
}
}
if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2]
buf.WriteByte(PropSessionExpiryInterval)
buf.Write(encodeUint32(p.SessionExpiryInterval))
}
if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" {
buf.WriteByte(PropAssignedClientID)
buf.Write(encodeString(p.AssignedClientID))
}
if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag {
buf.WriteByte(PropServerKeepAlive)
buf.Write(encodeUint16(p.ServerKeepAlive))
}
if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" {
buf.WriteByte(PropAuthenticationMethod)
buf.Write(encodeString(p.AuthenticationMethod))
}
if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 {
buf.WriteByte(PropAuthenticationData)
buf.Write(encodeBytes(p.AuthenticationData))
}
if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag {
buf.WriteByte(PropRequestProblemInfo)
buf.WriteByte(p.RequestProblemInfo)
}
if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 {
buf.WriteByte(PropWillDelayInterval)
buf.Write(encodeUint32(p.WillDelayInterval))
}
if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 {
buf.WriteByte(PropRequestResponseInfo)
buf.WriteByte(p.RequestResponseInfo)
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo))
}
if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 {
buf.WriteByte(PropServerReference)
buf.Write(encodeString(p.ServerReference))
}
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString)
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
buf.WriteByte(PropReasonString)
buf.Write(b)
}
}
if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 {
buf.WriteByte(PropReceiveMaximum)
buf.Write(encodeUint16(p.ReceiveMaximum))
}
if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 {
buf.WriteByte(PropTopicAliasMaximum)
buf.Write(encodeUint16(p.TopicAliasMaximum))
}
if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8]
buf.WriteByte(PropTopicAlias)
buf.Write(encodeUint16(p.TopicAlias))
}
if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 {
buf.WriteByte(PropMaximumQos)
buf.WriteByte(p.MaximumQos)
}
if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag {
buf.WriteByte(PropRetainAvailable)
buf.WriteByte(p.RetainAvailable)
}
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := bytes.NewBuffer([]byte{})
for _, v := range p.User {
pb.WriteByte(PropUser)
pb.Write(encodeString(v.Key))
pb.Write(encodeString(v.Val))
}
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
buf.Write(pb.Bytes())
}
}
if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 {
buf.WriteByte(PropMaximumPacketSize)
buf.Write(encodeUint32(p.MaximumPacketSize))
}
if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag {
buf.WriteByte(PropWildcardSubAvailable)
buf.WriteByte(p.WildcardSubAvailable)
}
if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag {
buf.WriteByte(PropSubIDAvailable)
buf.WriteByte(p.SubIDAvailable)
}
if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag {
buf.WriteByte(PropSharedSubAvailable)
buf.WriteByte(p.SharedSubAvailable)
}
encodeLength(b, int64(buf.Len()))
buf.WriteTo(b) // [MQTT-3.1.3-10]
}
// Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
if p == nil {
return 0, nil
}
var bu int
n, bu, err = DecodeLength(b)
if err != nil {
return n + bu, err
}
if n == 0 {
return n + bu, nil
}
bt := b.Bytes()
var k byte
for offset := 0; offset < n; {
k, offset, err = decodeByte(bt, offset)
if err != nil {
return n + bu, err
}
if _, ok := validPacketProperties[k][pk]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
}
switch k {
case PropPayloadFormat:
p.PayloadFormat, offset, err = decodeByte(bt, offset)
p.PayloadFormatFlag = true
case PropMessageExpiryInterval:
p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset)
case PropContentType:
p.ContentType, offset, err = decodeString(bt, offset)
case PropResponseTopic:
p.ResponseTopic, offset, err = decodeString(bt, offset)
case PropCorrelationData:
p.CorrelationData, offset, err = decodeBytes(bt, offset)
case PropSubscriptionIdentifier:
if p.SubscriptionIdentifier == nil {
p.SubscriptionIdentifier = []int{}
}
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
if err != nil {
return n + bu, err
}
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
offset += bu
case PropSessionExpiryInterval:
p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset)
p.SessionExpiryIntervalFlag = true
case PropAssignedClientID:
p.AssignedClientID, offset, err = decodeString(bt, offset)
case PropServerKeepAlive:
p.ServerKeepAlive, offset, err = decodeUint16(bt, offset)
p.ServerKeepAliveFlag = true
case PropAuthenticationMethod:
p.AuthenticationMethod, offset, err = decodeString(bt, offset)
case PropAuthenticationData:
p.AuthenticationData, offset, err = decodeBytes(bt, offset)
case PropRequestProblemInfo:
p.RequestProblemInfo, offset, err = decodeByte(bt, offset)
p.RequestProblemInfoFlag = true
case PropWillDelayInterval:
p.WillDelayInterval, offset, err = decodeUint32(bt, offset)
case PropRequestResponseInfo:
p.RequestResponseInfo, offset, err = decodeByte(bt, offset)
case PropResponseInfo:
p.ResponseInfo, offset, err = decodeString(bt, offset)
case PropServerReference:
p.ServerReference, offset, err = decodeString(bt, offset)
case PropReasonString:
p.ReasonString, offset, err = decodeString(bt, offset)
case PropReceiveMaximum:
p.ReceiveMaximum, offset, err = decodeUint16(bt, offset)
case PropTopicAliasMaximum:
p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset)
case PropTopicAlias:
p.TopicAlias, offset, err = decodeUint16(bt, offset)
p.TopicAliasFlag = true
case PropMaximumQos:
p.MaximumQos, offset, err = decodeByte(bt, offset)
p.MaximumQosFlag = true
case PropRetainAvailable:
p.RetainAvailable, offset, err = decodeByte(bt, offset)
p.RetainAvailableFlag = true
case PropUser:
var k, v string
k, offset, err = decodeString(bt, offset)
if err != nil {
return n + bu, err
}
v, offset, err = decodeString(bt, offset)
p.User = append(p.User, UserProperty{Key: k, Val: v})
case PropMaximumPacketSize:
p.MaximumPacketSize, offset, err = decodeUint32(bt, offset)
case PropWildcardSubAvailable:
p.WildcardSubAvailable, offset, err = decodeByte(bt, offset)
p.WildcardSubAvailableFlag = true
case PropSubIDAvailable:
p.SubIDAvailable, offset, err = decodeByte(bt, offset)
p.SubIDAvailableFlag = true
case PropSharedSubAvailable:
p.SharedSubAvailable, offset, err = decodeByte(bt, offset)
p.SharedSubAvailableFlag = true
}
if err != nil {
return n + bu, err
}
}
return n + bu, nil
}

View File

@ -1,333 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var (
propertiesStruct = Properties{
PayloadFormat: byte(1), // UTF-8 Format
PayloadFormatFlag: true,
MessageExpiryInterval: uint32(2),
ContentType: "text/plain",
ResponseTopic: "a/b/c",
CorrelationData: []byte("data"),
SubscriptionIdentifier: []int{322122},
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
AssignedClientID: "mochi-v5",
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
AuthenticationMethod: "SHA-1",
AuthenticationData: []byte("auth-data"),
RequestProblemInfo: byte(1),
RequestProblemInfoFlag: true,
WillDelayInterval: uint32(600),
RequestResponseInfo: byte(1),
ResponseInfo: "response",
ServerReference: "mochi-2",
ReasonString: "reason",
ReceiveMaximum: uint16(500),
TopicAliasMaximum: uint16(999),
TopicAlias: uint16(3),
TopicAliasFlag: true,
MaximumQos: byte(1),
MaximumQosFlag: true,
RetainAvailable: byte(1),
RetainAvailableFlag: true,
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
{
Key: "key2",
Val: "value2",
},
},
MaximumPacketSize: uint32(32000),
WildcardSubAvailable: byte(1),
WildcardSubAvailableFlag: true,
SubIDAvailable: byte(1),
SubIDAvailableFlag: true,
SharedSubAvailable: byte(1),
SharedSubAvailableFlag: true,
}
propertiesBytes = []byte{
172, 1, // VBI
// Payload Format (1) (vbi:2)
1, 1,
// Message Expiry (2) (vbi:7)
2, 0, 0, 0, 2,
// Content Type (3) (vbi:20)
3,
0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n',
// Response Topic (8) (vbi:28)
8,
0, 5, 'a', '/', 'b', '/', 'c',
// Correlations Data (9) (vbi:35)
9,
0, 4, 'd', 'a', 't', 'a',
// Subscription Identifier (11) (vbi:39)
11,
202, 212, 19,
// Session Expiry Interval (17) (vbi:43)
17,
0, 0, 0, 120,
// Assigned Client ID (18) (vbi:55)
18,
0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5',
// Server Keep Alive (19) (vbi:58)
19,
0, 20,
// Authentication Method (21) (vbi:66)
21,
0, 5, 'S', 'H', 'A', '-', '1',
// Authentication Data (22) (vbi:78)
22,
0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a',
// Request Problem Info (23) (vbi:80)
23, 1,
// Will Delay Interval (24) (vbi:85)
24,
0, 0, 2, 88,
// Request Response Info (25) (vbi:87)
25, 1,
// Response Info (26) (vbi:98)
26,
0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e',
// Server Reference (28) (vbi:108)
28,
0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2',
// Reason String (31) (vbi:117)
31,
0, 6, 'r', 'e', 'a', 's', 'o', 'n',
// Receive Maximum (33) (vbi:120)
33,
1, 244,
// Topic Alias Maximum (34) (vbi:123)
34,
3, 231,
// Topic Alias (35) (vbi:126)
35,
0, 3,
// Maximum Qos (36) (vbi:128)
36, 1,
// Retain Available (37) (vbi: 130)
37, 1,
// User Properties (38) (vbi:161)
38,
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
38,
0, 4, 'k', 'e', 'y', '2',
0, 6, 'v', 'a', 'l', 'u', 'e', '2',
// Maximum Packet Size (39) (vbi:166)
39,
0, 0, 125, 0,
// Wildcard Subscriptions Available (40) (vbi:168)
40, 1,
// Subscription ID Available (41) (vbi:170)
41, 1,
// Shared Subscriptions Available (42) (vbi:172)
42, 1,
}
)
func init() {
validPacketProperties[PropPayloadFormat][Reserved] = 1
validPacketProperties[PropMessageExpiryInterval][Reserved] = 1
validPacketProperties[PropContentType][Reserved] = 1
validPacketProperties[PropResponseTopic][Reserved] = 1
validPacketProperties[PropCorrelationData][Reserved] = 1
validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1
validPacketProperties[PropSessionExpiryInterval][Reserved] = 1
validPacketProperties[PropAssignedClientID][Reserved] = 1
validPacketProperties[PropServerKeepAlive][Reserved] = 1
validPacketProperties[PropAuthenticationMethod][Reserved] = 1
validPacketProperties[PropAuthenticationData][Reserved] = 1
validPacketProperties[PropRequestProblemInfo][Reserved] = 1
validPacketProperties[PropWillDelayInterval][Reserved] = 1
validPacketProperties[PropRequestResponseInfo][Reserved] = 1
validPacketProperties[PropResponseInfo][Reserved] = 1
validPacketProperties[PropServerReference][Reserved] = 1
validPacketProperties[PropReasonString][Reserved] = 1
validPacketProperties[PropReceiveMaximum][Reserved] = 1
validPacketProperties[PropTopicAliasMaximum][Reserved] = 1
validPacketProperties[PropTopicAlias][Reserved] = 1
validPacketProperties[PropMaximumQos][Reserved] = 1
validPacketProperties[PropRetainAvailable][Reserved] = 1
validPacketProperties[PropUser][Reserved] = 1
validPacketProperties[PropMaximumPacketSize][Reserved] = 1
validPacketProperties[PropWildcardSubAvailable][Reserved] = 1
validPacketProperties[PropSubIDAvailable][Reserved] = 1
validPacketProperties[PropSharedSubAvailable][Reserved] = 1
}
func TestEncodeProperties(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes())
}
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8}))
}
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
}
func TestEncodePropertiesNil(t *testing.T) {
type tmp struct {
p *Properties
}
pr := tmp{}
b := bytes.NewBuffer([]byte{})
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
require.Equal(t, []byte{}, b.Bytes())
}
func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties)
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes())
}
func TestDecodeProperties(t *testing.T) {
b := bytes.NewBuffer(propertiesBytes)
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172+2, n)
require.EqualValues(t, propertiesStruct, *props)
}
func TestDecodePropertiesNil(t *testing.T) {
b := bytes.NewBuffer(propertiesBytes)
type tmp struct {
p *Properties
}
pr := tmp{}
n, err := pr.p.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 0, n)
}
func TestDecodePropertiesBadInitialVBI(t *testing.T) {
b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, ErrMalformedVariableByteInteger, err)
}
func TestDecodePropertiesZeroLengthVBI(t *testing.T) {
b := bytes.NewBuffer([]byte{0})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, props, new(Properties))
}
func TestDecodePropertiesBadKeyByte(t *testing.T) {
b := bytes.NewBuffer([]byte{64, 1})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange)
}
func TestDecodePropertiesInvalidForPacket(t *testing.T) {
b := bytes.NewBuffer([]byte{1, 99})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty)
}
func TestDecodePropertiesGeneralFailure(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestDecodePropertiesBadSubscriptionID(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestDecodePropertiesBadUserProps(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestCopyProperties(t *testing.T) {
require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true))
}
func TestCopyPropertiesNoTransfer(t *testing.T) {
pkA := propertiesStruct
pkB := pkA.Copy(false)
// Properties which should never be transferred from one connection to another
require.Equal(t, uint16(0), pkB.TopicAlias)
}

File diff suppressed because it is too large Load Diff

View File

@ -1,33 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func encodeTestOK(wanted TPacketCase) bool {
if wanted.RawBytes == nil {
return false
}
if wanted.Group != "" && wanted.Group != "encode" {
return false
}
return true
}
func decodeTestOK(wanted TPacketCase) bool {
if wanted.Group != "" && wanted.Group != "decode" {
return false
}
return true
}
func TestTPacketCaseGet(t *testing.T) {
require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311))
require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128)))
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,61 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package system
import "sync/atomic"
// Info contains atomic counters and values for various server statistics
// commonly found in $SYS topics (and others).
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
type Info struct {
Version string `json:"version"` // the current version of the server
Started int64 `json:"started"` // the time the server started in unix seconds
Time int64 `json:"time"` // current time on the server
Uptime int64 `json:"uptime"` // the number of seconds the server has been online
BytesReceived int64 `json:"bytes_received"` // total number of bytes received since the broker started
BytesSent int64 `json:"bytes_sent"` // total number of bytes sent since the broker started
ClientsConnected int64 `json:"clients_connected"` // number of currently connected clients
ClientsDisconnected int64 `json:"clients_disconnected"` // total number of persistent clients (with clean session disabled) that are registered at the broker but are currently disconnected
ClientsMaximum int64 `json:"clients_maximum"` // maximum number of active clients that have been connected
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
Retained int64 `json:"retained"` // total number of retained messages active on the broker
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
Subscriptions int64 `json:"subscriptions"` // total number of subscriptions active on the broker
PacketsReceived int64 `json:"packets_received"` // the total number of publish messages received
PacketsSent int64 `json:"packets_sent"` // total number of messages of any type sent since the broker started
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
}
// Clone makes a copy of Info using atomic operation
func (i *Info) Clone() *Info {
return &Info{
Version: i.Version,
Started: atomic.LoadInt64(&i.Started),
Time: atomic.LoadInt64(&i.Time),
Uptime: atomic.LoadInt64(&i.Uptime),
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
BytesSent: atomic.LoadInt64(&i.BytesSent),
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
Retained: atomic.LoadInt64(&i.Retained),
Inflight: atomic.LoadInt64(&i.Inflight),
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
Threads: atomic.LoadInt64(&i.Threads),
}
}

View File

@ -1,37 +0,0 @@
package system
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestClone(t *testing.T) {
o := &Info{
Version: "version",
Started: 1,
Time: 2,
Uptime: 3,
BytesReceived: 4,
BytesSent: 5,
ClientsConnected: 6,
ClientsMaximum: 7,
ClientsTotal: 8,
ClientsDisconnected: 9,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
Retained: 12,
Inflight: 13,
InflightDropped: 14,
Subscriptions: 15,
PacketsReceived: 16,
PacketsSent: 17,
MemoryAlloc: 18,
Threads: 19,
}
n := o.Clone()
require.Equal(t, o, n)
}

View File

@ -1,699 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"strings"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/v2/packets"
)
var (
SharePrefix = "$SHARE" // the prefix indicating a share topic
SysPrefix = "$SYS" // the prefix indicating a system info topic
)
// TopicAliases contains inbound and outbound topic alias registrations.
type TopicAliases struct {
Inbound *InboundTopicAliases
Outbound *OutboundTopicAliases
}
// NewTopicAliases returns an instance of TopicAliases.
func NewTopicAliases(topicAliasMaximum uint16) TopicAliases {
return TopicAliases{
Inbound: NewInboundTopicAliases(topicAliasMaximum),
Outbound: NewOutboundTopicAliases(topicAliasMaximum),
}
}
// NewInboundTopicAliases returns a pointer to InboundTopicAliases.
func NewInboundTopicAliases(topicAliasMaximum uint16) *InboundTopicAliases {
return &InboundTopicAliases{
maximum: topicAliasMaximum,
internal: map[uint16]string{},
}
}
// InboundTopicAliases contains a map of topic aliases received from the client.
type InboundTopicAliases struct {
internal map[uint16]string
sync.RWMutex
maximum uint16
}
// Set sets a new alias for a specific topic.
func (a *InboundTopicAliases) Set(id uint16, topic string) string {
a.Lock()
defer a.Unlock()
if a.maximum == 0 {
return topic // ?
}
if existing, ok := a.internal[id]; ok && topic == "" {
return existing
}
a.internal[id] = topic
return topic
}
// OutboundTopicAliases contains a map of topic aliases sent from the broker to the client.
type OutboundTopicAliases struct {
internal map[string]uint16
sync.RWMutex
cursor uint32
maximum uint16
}
// NewOutboundTopicAliases returns a pointer to OutboundTopicAliases.
func NewOutboundTopicAliases(topicAliasMaximum uint16) *OutboundTopicAliases {
return &OutboundTopicAliases{
maximum: topicAliasMaximum,
internal: map[string]uint16{},
}
}
// Set sets a new topic alias for a topic and returns the alias value, and a boolean
// indicating if the alias already existed.
func (a *OutboundTopicAliases) Set(topic string) (uint16, bool) {
a.Lock()
defer a.Unlock()
if a.maximum == 0 {
return 0, false
}
if i, ok := a.internal[topic]; ok {
return i, true
}
i := atomic.LoadUint32(&a.cursor)
if i+1 > uint32(a.maximum) {
// if i+1 > math.MaxUint16 {
return 0, false
}
a.internal[topic] = uint16(i) + 1
atomic.StoreUint32(&a.cursor, i+1)
return uint16(i) + 1, false
}
// SharedSubscriptions contains a map of subscriptions to a shared filter,
// keyed on share group then client id.
type SharedSubscriptions struct {
internal map[string]map[string]packets.Subscription
sync.RWMutex
}
// NewSharedSubscriptions returns a new instance of Subscriptions.
func NewSharedSubscriptions() *SharedSubscriptions {
return &SharedSubscriptions{
internal: map[string]map[string]packets.Subscription{},
}
}
// Add creates a new shared subscription for a group and client id pair.
func (s *SharedSubscriptions) Add(group, id string, val packets.Subscription) {
s.Lock()
defer s.Unlock()
if _, ok := s.internal[group]; !ok {
s.internal[group] = map[string]packets.Subscription{}
}
s.internal[group][id] = val
}
// Delete deletes a client id from a shared subscription group.
func (s *SharedSubscriptions) Delete(group, id string) {
s.Lock()
defer s.Unlock()
delete(s.internal[group], id)
if len(s.internal[group]) == 0 {
delete(s.internal, group)
}
}
// Get returns the subscription properties for a client id in a share group, if one exists.
func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, ok bool) {
s.RLock()
defer s.RUnlock()
if _, ok := s.internal[group]; !ok {
return val, ok
}
val, ok = s.internal[group][id]
return val, ok
}
// GroupLen returns the number of groups subscribed to the filter.
func (s *SharedSubscriptions) GroupLen() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Len returns the total number of shared subscriptions to a filter across all groups.
func (s *SharedSubscriptions) Len() int {
s.RLock()
defer s.RUnlock()
n := 0
for _, group := range s.internal {
n += len(group)
}
return n
}
// GetAll returns all shared subscription groups and their subscriptions.
func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscription {
s.RLock()
defer s.RUnlock()
m := map[string]map[string]packets.Subscription{}
for group, subs := range s.internal {
if _, ok := m[group]; !ok {
m[group] = map[string]packets.Subscription{}
}
for id, sub := range subs {
m[group][id] = sub
}
}
return m
}
// Subscriptions is a map of subscriptions keyed on client.
type Subscriptions struct {
internal map[string]packets.Subscription
sync.RWMutex
}
// NewSubscriptions returns a new instance of Subscriptions.
func NewSubscriptions() *Subscriptions {
return &Subscriptions{
internal: map[string]packets.Subscription{},
}
}
// Add adds a new subscription for a client. ID can be a filter in the
// case this map is client state, or a client id if particle state.
func (s *Subscriptions) Add(id string, val packets.Subscription) {
s.Lock()
defer s.Unlock()
s.internal[id] = val
}
// GetAll returns all subscriptions.
func (s *Subscriptions) GetAll() map[string]packets.Subscription {
s.RLock()
defer s.RUnlock()
m := map[string]packets.Subscription{}
for k, v := range s.internal {
m[k] = v
}
return m
}
// Get returns a subscriptions for a specific client or filter id.
func (s *Subscriptions) Get(id string) (val packets.Subscription, ok bool) {
s.RLock()
defer s.RUnlock()
val, ok = s.internal[id]
return val, ok
}
// Len returns the number of subscriptions.
func (s *Subscriptions) Len() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Delete removes a subscription by client or filter id.
func (s *Subscriptions) Delete(id string) {
s.Lock()
defer s.Unlock()
delete(s.internal, id)
}
// ClientSubscriptions is a map of aggregated subscriptions for a client.
type ClientSubscriptions map[string]packets.Subscription
// Subscribers contains the shared and non-shared subscribers matching a topic.
type Subscribers struct {
Shared map[string]map[string]packets.Subscription
SharedSelected map[string]packets.Subscription
Subscriptions map[string]packets.Subscription
}
// SelectShared returns one subscriber for each shared subscription group.
func (s *Subscribers) SelectShared() {
s.SharedSelected = map[string]packets.Subscription{}
for _, subs := range s.Shared {
for client, sub := range subs {
cls, ok := s.SharedSelected[client]
if !ok {
cls = sub
}
s.SharedSelected[client] = cls.Merge(sub)
break
}
}
}
// MergeSharedSelected merges the selected subscribers for a shared subscription group
// and the non-shared subscribers, to ensure that no subscriber gets multiple messages
// due to have both types of subscription matching the same filter.
func (s *Subscribers) MergeSharedSelected() {
for client, sub := range s.SharedSelected {
cls, ok := s.Subscriptions[client]
if !ok {
cls = sub
}
s.Subscriptions[client] = cls.Merge(sub)
}
}
// TopicsIndex is a prefix/trie tree containing topic subscribers and retained messages.
type TopicsIndex struct {
Retained *packets.Packets
root *particle // a leaf containing a message and more leaves.
}
// NewTopicsIndex returns a pointer to a new instance of Index.
func NewTopicsIndex() *TopicsIndex {
return &TopicsIndex{
Retained: packets.NewPackets(),
root: &particle{
particles: newParticles(),
subscriptions: NewSubscriptions(),
},
}
}
// Subscribe adds a new subscription for a client to a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
prefix, _ := isolateParticle(subscription.Filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
group, _ := isolateParticle(subscription.Filter, 1)
n := x.set(subscription.Filter, 2)
_, existed = n.shared.Get(group, client)
n.shared.Add(group, client, subscription)
} else {
n := x.set(subscription.Filter, 0)
_, existed = n.subscriptions.Get(client)
n.subscriptions.Add(client, subscription)
}
return !existed
}
// Unsubscribe removes a subscription filter for a client, returning true if the
// subscription existed.
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
x.root.Lock()
defer x.root.Unlock()
var d int
if strings.HasPrefix(filter, SharePrefix) {
d = 2
}
particle := x.seek(filter, d)
if particle == nil {
return false
}
prefix, _ := isolateParticle(filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
group, _ := isolateParticle(filter, 1)
particle.shared.Delete(group, client)
} else {
particle.subscriptions.Delete(client)
}
x.trim(particle)
return true
}
// RetainMessage saves a message payload to the end of a topic address. Returns
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
x.root.Lock()
defer x.root.Unlock()
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()
if len(pk.Payload) > 0 {
n.retainPath = pk.TopicName
x.Retained.Add(pk.TopicName, pk)
return 1
}
var out int64
if pke, ok := x.Retained.Get(pk.TopicName); ok && len(pke.Payload) > 0 && pke.FixedHeader.Retain {
out = -1 // if a retained packet existed, return -1
}
n.retainPath = ""
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
x.trim(n)
return out
}
// set creates a topic address in the index and returns the final particle.
func (x *TopicsIndex) set(topic string, d int) *particle {
var key string
var hasNext = true
n := x.root
for hasNext {
key, hasNext = isolateParticle(topic, d)
d++
p := n.particles.get(key)
if p == nil {
p = newParticle(key, n)
n.particles.add(p)
}
n = p
}
return n
}
// seek finds the particle at a specific index in a topic filter.
func (x *TopicsIndex) seek(filter string, d int) *particle {
var key string
var hasNext = true
n := x.root
for hasNext {
key, hasNext = isolateParticle(filter, d)
n = n.particles.get(key)
d++
if n == nil {
return nil
}
}
return n
}
// trim removes empty filter particles from the index.
func (x *TopicsIndex) trim(n *particle) {
for n.parent != nil && n.retainPath == "" && n.particles.len()+n.subscriptions.Len()+n.shared.Len() == 0 {
key := n.key
n = n.parent
n.particles.delete(key)
}
}
// Messages returns a slice of any retained messages which match a filter.
func (x *TopicsIndex) Messages(filter string) []packets.Packet {
return x.scanMessages(filter, 0, nil, []packets.Packet{})
}
// scanMessages returns all retained messages on topics matching a given filter.
func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []packets.Packet) []packets.Packet {
if n == nil {
n = x.root
}
if len(filter) == 0 || x.Retained.Len() == 0 {
return pks
}
if !strings.ContainsRune(filter, '#') && !strings.ContainsRune(filter, '+') {
if pk, ok := x.Retained.Get(filter); ok {
pks = append(pks, pk)
}
return pks
}
key, hasNext := isolateParticle(filter, d)
if key == "+" || key == "#" || d == -1 {
for _, adjacent := range n.particles.getAll() {
if d == 0 && adjacent.key == SysPrefix {
continue
}
if !hasNext {
if adjacent.retainPath != "" {
if pk, ok := x.Retained.Get(adjacent.retainPath); ok {
pks = append(pks, pk)
}
}
}
if hasNext || (d >= 0 && key == "#") {
pks = x.scanMessages(filter, d+1, adjacent, pks)
}
}
return pks
}
if particle := n.particles.get(key); particle != nil {
if hasNext {
return x.scanMessages(filter, d+1, particle, pks)
}
if pk, ok := x.Retained.Get(particle.retainPath); ok {
pks = append(pks, pk)
}
}
return pks
}
// Subscribers returns a map of clients who are subscribed to matching filters,
// their subscription ids and highest qos.
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
return x.scanSubscribers(topic, 0, nil, &Subscribers{
Shared: map[string]map[string]packets.Subscription{},
SharedSelected: map[string]packets.Subscription{},
Subscriptions: map[string]packets.Subscription{},
})
}
// scanSubscribers returns a list of client subscriptions matching an indexed topic address.
func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Subscribers) *Subscribers {
if n == nil {
n = x.root
}
if len(topic) == 0 {
return subs
}
key, hasNext := isolateParticle(topic, d)
for _, partKey := range []string{key, "+", "#"} {
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
if wild := particle.particles.get("#"); wild != nil && partKey != "#" && partKey != "+" {
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
}
if hasNext {
x.scanSubscribers(topic, d+1, particle, subs)
}
}
}
return subs
}
// gatherSubscriptions collects any matching subscriptions, and gathers any identifiers or highest qos values.
func (x *TopicsIndex) gatherSubscriptions(topic string, particle *particle, subs *Subscribers) {
if subs.Subscriptions == nil {
subs.Subscriptions = map[string]packets.Subscription{}
}
for client, sub := range particle.subscriptions.GetAll() {
if len(sub.Filter) > 0 && topic[0] == '$' && (sub.Filter[0] == '+' || sub.Filter[0] == '#') { // don't match $ topics with top level wildcards [MQTT-4.7.1-1] [MQTT-4.7.1-2]
continue
}
cls, ok := subs.Subscriptions[client]
if !ok {
cls = sub
}
subs.Subscriptions[client] = cls.Merge(sub)
}
}
// gatherSharedSubscriptions gathers all shared subscriptions for a particle.
func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscribers) {
if subs.Shared == nil {
subs.Shared = map[string]map[string]packets.Subscription{}
}
for _, shares := range particle.shared.GetAll() {
for client, sub := range shares {
if _, ok := subs.Shared[sub.Filter]; !ok {
subs.Shared[sub.Filter] = map[string]packets.Subscription{}
}
subs.Shared[sub.Filter][client] = sub
}
}
}
// isolateParticle extracts a particle between d / and d+1 / without allocations.
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
var next, end int
for i := 0; end > -1 && i <= d; i++ {
end = strings.IndexRune(filter, '/')
switch {
case d > -1 && i == d && end > -1:
hasNext = true
particle = filter[next:end]
case end > -1:
hasNext = false
filter = filter[end+1:]
default:
hasNext = false
particle = filter[next:]
}
}
return
}
// IsSharedFilter returns true if the filter uses the share prefix.
func IsSharedFilter(filter string) bool {
prefix, _ := isolateParticle(filter, 0)
return strings.EqualFold(prefix, SharePrefix)
}
// IsValidFilter returns true if the filter is valid.
func IsValidFilter(filter string, forPublish bool) bool {
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs.
return false // [MQTT-4.7.3-1]
}
if forPublish {
if len(filter) >= len(SysPrefix) && strings.EqualFold(filter[0:len(SysPrefix)], SysPrefix) {
// 4.7.2 Non-normative - The Server SHOULD prevent Clients from using such Topic Names [$SYS] to exchange messages with other Clients.
return false
}
if strings.ContainsRune(filter, '+') || strings.ContainsRune(filter, '#') {
return false //[MQTT-3.3.2-2]
}
}
wildhash := strings.IndexRune(filter, '#')
if wildhash >= 0 && wildhash != len(filter)-1 { // [MQTT-4.7.1-2]
return false
}
prefix, hasNext := isolateParticle(filter, 0)
if !hasNext && strings.EqualFold(prefix, SharePrefix) {
return false // [MQTT-4.8.2-1]
}
if hasNext && strings.EqualFold(prefix, SharePrefix) {
group, hasNext := isolateParticle(filter, 1)
if !hasNext {
return false // [MQTT-4.8.2-1]
}
if strings.ContainsRune(group, '+') || strings.ContainsRune(group, '#') {
return false // [MQTT-4.8.2-2]
}
}
return true
}
// particle is a child node on the tree.
type particle struct {
key string // the key of the particle
parent *particle // a pointer to the parent of the particle
particles particles // a map of child particles
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle
}
// newParticle returns a pointer to a new instance of particle.
func newParticle(key string, parent *particle) *particle {
return &particle{
key: key,
parent: parent,
particles: newParticles(),
subscriptions: NewSubscriptions(),
shared: NewSharedSubscriptions(),
}
}
// particles is a concurrency safe map of particles.
type particles struct {
internal map[string]*particle
sync.RWMutex
}
// newParticles returns a map of particles.
func newParticles() particles {
return particles{
internal: map[string]*particle{},
}
}
// add adds a new particle.
func (p *particles) add(val *particle) {
p.Lock()
p.internal[val.key] = val
p.Unlock()
}
// getAll returns all particles.
func (p *particles) getAll() map[string]*particle {
p.RLock()
defer p.RUnlock()
m := map[string]*particle{}
for k, v := range p.internal {
m[k] = v
}
return m
}
// get returns a particle by id (key).
func (p *particles) get(id string) *particle {
p.RLock()
defer p.RUnlock()
return p.internal[id]
}
// len returns the number of particles.
func (p *particles) len() int {
p.RLock()
defer p.RUnlock()
val := len(p.internal)
return val
}
// delete removes a particle.
func (p *particles) delete(id string) {
p.Lock()
defer p.Unlock()
delete(p.internal, id)
}

View File

@ -1,842 +0,0 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"testing"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
const (
testGroup = "testgroup"
otherGroup = "other"
)
func TestNewSharedSubscriptions(t *testing.T) {
s := NewSharedSubscriptions()
require.NotNil(t, s.internal)
}
func TestSharedSubscriptionsAdd(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Filter: "a/b/c"})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
}
func TestSharedSubscriptionsGet(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Qos: 2})
s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl2")
sub, ok := s.Get(testGroup, "cl2")
require.Equal(t, true, ok)
require.Equal(t, byte(2), sub.Qos)
}
func TestSharedSubscriptionsGetAll(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
s.Add(otherGroup, "cl3", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl2")
require.Contains(t, s.internal, otherGroup)
require.Contains(t, s.internal[otherGroup], "cl3")
subs := s.GetAll()
require.Len(t, subs, 2)
require.Len(t, subs[testGroup], 2)
require.Len(t, subs[otherGroup], 1)
}
func TestSharedSubscriptionsLen(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
s.Add(otherGroup, "cl2", packets.Subscription{Qos: 1})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl2")
require.Contains(t, s.internal, otherGroup)
require.Contains(t, s.internal[otherGroup], "cl2")
require.Equal(t, 3, s.Len())
require.Equal(t, 2, s.GroupLen())
}
func TestSharedSubscriptionsDelete(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Qos: 1})
s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl2")
require.Equal(t, 2, s.Len())
s.Delete(testGroup, "cl1")
_, ok := s.Get(testGroup, "cl1")
require.False(t, ok)
require.Equal(t, 1, s.GroupLen())
require.Equal(t, 1, s.Len())
s.Delete(testGroup, "cl2")
_, ok = s.Get(testGroup, "cl2")
require.False(t, ok)
require.Equal(t, 0, s.GroupLen())
require.Equal(t, 0, s.Len())
}
func TestNewSubscriptions(t *testing.T) {
s := NewSubscriptions()
require.NotNil(t, s.internal)
}
func TestSubscriptionsAdd(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{})
require.Contains(t, s.internal, "cl1")
}
func TestSubscriptionsGet(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{Qos: 2})
s.Add("cl2", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
sub, ok := s.Get("cl1")
require.True(t, ok)
require.Equal(t, byte(2), sub.Qos)
}
func TestSubscriptionsGetAll(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{Qos: 0})
s.Add("cl2", packets.Subscription{Qos: 1})
s.Add("cl3", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Contains(t, s.internal, "cl3")
subs := s.GetAll()
require.Len(t, subs, 3)
}
func TestSubscriptionsLen(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{Qos: 0})
s.Add("cl2", packets.Subscription{Qos: 1})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Equal(t, 2, s.Len())
}
func TestSubscriptionsDelete(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{Qos: 1})
require.Contains(t, s.internal, "cl1")
s.Delete("cl1")
_, ok := s.Get("cl1")
require.False(t, ok)
}
func TestNewTopicsIndex(t *testing.T) {
index := NewTopicsIndex()
require.NotNil(t, index)
require.NotNil(t, index.root)
}
func BenchmarkNewTopicsIndex(b *testing.B) {
for n := 0; n < b.N; n++ {
NewTopicsIndex()
}
}
func TestSubscribe(t *testing.T) {
tt := []struct {
desc string
client string
filter string
subscription packets.Subscription
wasNew bool
}{
{
desc: "subscribe",
client: "cl1",
subscription: packets.Subscription{Filter: "a/b/c", Qos: 2},
wasNew: true,
},
{
desc: "subscribe existed",
client: "cl1",
subscription: packets.Subscription{Filter: "a/b/c", Qos: 1},
wasNew: false,
},
{
desc: "subscribe case sensitive didnt exist",
client: "cl1",
subscription: packets.Subscription{Filter: "A/B/c", Qos: 1},
wasNew: true,
},
{
desc: "wildcard+ sub",
client: "cl1",
subscription: packets.Subscription{Filter: "d/+"},
wasNew: true,
},
{
desc: "wildcard# sub",
client: "cl1",
subscription: packets.Subscription{Filter: "d/e/#"},
wasNew: true,
},
}
index := NewTopicsIndex()
for _, tx := range tt {
t.Run(tx.desc, func(t *testing.T) {
require.Equal(t, tx.wasNew, index.Subscribe(tx.client, tx.subscription))
})
}
final := index.root.particles.get("a").particles.get("b").particles.get("c")
require.NotNil(t, final)
client, exists := final.subscriptions.Get("cl1")
require.True(t, exists)
require.Equal(t, byte(1), client.Qos)
}
func TestSubscribeShared(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c", Qos: 2})
final := index.root.particles.get("a").particles.get("b").particles.get("c")
require.NotNil(t, final)
client, exists := final.shared.Get("tmp", "cl1")
require.True(t, exists)
require.Equal(t, byte(2), client.Qos)
require.Equal(t, 0, final.subscriptions.Len())
require.Equal(t, 1, final.shared.Len())
}
func BenchmarkSubscribe(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
index.Subscribe("client-1", packets.Subscription{Filter: "a/b/c"})
}
}
func BenchmarkSubscribeShared(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
index.Subscribe("client-1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c"})
}
}
func TestUnsubscribe(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/d", Qos: 1})
client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/+/d", Qos: 1})
client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
index.Subscribe("cl1", packets.Subscription{Filter: "d/e/f", Qos: 1})
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
index.Subscribe("cl2", packets.Subscription{Filter: "d/e/f", Qos: 1})
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
require.NotNil(t, client)
require.True(t, exists)
index.Subscribe("cl3", packets.Subscription{Filter: "#", Qos: 2})
client, exists = index.root.particles.get("#").subscriptions.Get("cl3")
require.NotNil(t, client)
require.True(t, exists)
ok := index.Unsubscribe("a/b/c/d", "cl1")
require.True(t, ok)
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
ok = index.Unsubscribe("d/e/f", "cl1")
require.True(t, ok)
require.Equal(t, 1, index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Len())
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
require.NotNil(t, client)
require.True(t, exists)
ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "nobody")
require.False(t, ok)
}
func TestUnsubscribeNoCascade(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/e/e"})
ok := index.Unsubscribe("a/b/c/e/e", "cl1")
require.True(t, ok)
require.Equal(t, 1, index.root.particles.len())
client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
}
func TestUnsubscribeShared(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c", Qos: 2})
final := index.root.particles.get("a").particles.get("b").particles.get("c")
require.NotNil(t, final)
client, exists := final.shared.Get("tmp", "cl1")
require.True(t, exists)
require.Equal(t, byte(2), client.Qos)
require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1"))
_, exists = final.shared.Get("tmp", "cl1")
require.False(t, exists)
}
func BenchmarkUnsubscribe(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
b.StopTimer()
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
b.StartTimer()
index.Unsubscribe("a/b/c", "cl1")
}
}
func TestIndexSeek(t *testing.T) {
filter := "a/b/c/d/e/f"
index := NewTopicsIndex()
k1 := index.set(filter, 0)
require.Equal(t, "f", k1.key)
k1.subscriptions.Add("cl1", packets.Subscription{})
require.Equal(t, k1, index.seek(filter, 0))
require.Nil(t, index.seek("d/e/f", 0))
}
func TestIndexTrim(t *testing.T) {
index := NewTopicsIndex()
k1 := index.set("a/b/c", 0)
require.Equal(t, "c", k1.key)
k1.subscriptions.Add("cl1", packets.Subscription{})
k2 := index.set("a/b/c/d/e/f", 0)
require.Equal(t, "f", k2.key)
k2.subscriptions.Add("cl1", packets.Subscription{})
k3 := index.set("a/b", 0)
require.Equal(t, "b", k3.key)
k3.subscriptions.Add("cl1", packets.Subscription{})
index.trim(k2)
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").particles.get("e").particles.get("f"))
require.NotNil(t, index.root.particles.get("a").particles.get("b"))
k2.subscriptions.Delete("cl1")
index.trim(k2)
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d"))
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
k1.subscriptions.Delete("cl1")
k3.subscriptions.Delete("cl1")
index.trim(k2)
require.Nil(t, index.root.particles.get("a"))
}
func TestIndexSet(t *testing.T) {
index := NewTopicsIndex()
child := index.set("a/b/c", 0)
require.Equal(t, "c", child.key)
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
child = index.set("a/b/c/d/e", 0)
require.Equal(t, "e", child.key)
child = index.set("a/b/c/c/a", 0)
require.Equal(t, "a", child.key)
}
func TestIndexSetPrefixed(t *testing.T) {
index := NewTopicsIndex()
child := index.set("/c", 0)
require.Equal(t, "c", child.key)
require.NotNil(t, index.root.particles.get("").particles.get("c"))
}
func BenchmarkIndexSet(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
index.set("a/b/c", 0)
}
}
func TestRetainMessage(t *testing.T) {
pk := packets.Packet{
FixedHeader: packets.FixedHeader{Retain: true},
TopicName: "a/b/c",
Payload: []byte("hello"),
}
index := NewTopicsIndex()
r := index.RetainMessage(pk)
require.Equal(t, int64(1), r)
pke, ok := index.Retained.Get(pk.TopicName)
require.True(t, ok)
require.Equal(t, pk, pke)
pk2 := packets.Packet{
FixedHeader: packets.FixedHeader{Retain: true},
TopicName: "a/b/d/f",
Payload: []byte("hello"),
}
r = index.RetainMessage(pk2)
require.Equal(t, int64(1), r)
// The same message already exists, but we're not doing a deep-copy check, so it's considered to be a new message.
r = index.RetainMessage(pk2)
require.Equal(t, int64(1), r)
// Clear existing retained
pk3 := packets.Packet{TopicName: "a/b/c", Payload: []byte{}}
r = index.RetainMessage(pk3)
require.Equal(t, int64(-1), r)
_, ok = index.Retained.Get(pk.TopicName)
require.False(t, ok)
// Clear no retained
r = index.RetainMessage(pk3)
require.Equal(t, int64(0), r)
}
func BenchmarkRetainMessage(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
}
}
func TestIsolateParticle(t *testing.T) {
particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
require.Equal(t, "to", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
require.Equal(t, "my", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
require.Equal(t, "mqtt", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("/path/", 0)
require.Equal(t, "", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 1)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 2)
require.Equal(t, "", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 3)
require.Equal(t, "+", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 4)
require.Equal(t, "+", particle)
require.Equal(t, false, hasNext)
}
func BenchmarkIsolateParticle(b *testing.B) {
for n := 0; n < b.N; n++ {
isolateParticle("path/to/my/mqtt", 3)
}
}
func TestScanSubscribers(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c", Identifier: 22})
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c/d/e/f"})
index.Subscribe("cl1", packets.Subscription{Qos: 2, Filter: "a/b/c/d/+/f"})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/#"})
index.Subscribe("cl2", packets.Subscription{Qos: 1, Filter: "a/b/c"})
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "a/b/+", Identifier: 77})
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "d/e/f", Identifier: 7237})
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "$SYS/uptime", Identifier: 3})
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: "+/b", Identifier: 234})
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: "#", Identifier: 5})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 4, len(subs.Subscriptions))
require.Contains(t, subs.Subscriptions, "cl1")
require.Contains(t, subs.Subscriptions, "cl2")
require.Contains(t, subs.Subscriptions, "cl3")
require.Contains(t, subs.Subscriptions, "cl4")
require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos)
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"])
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
subs = index.scanSubscribers("", 0, nil, new(Subscribers))
require.Equal(t, 0, len(subs.Subscriptions))
}
func TestScanSubscribersShared(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 3, len(subs.Shared))
}
func TestSelectSharedSubscriber(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110})
index.Subscribe("cl1b", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 2, len(subs.Shared))
require.Contains(t, subs.Shared, SharePrefix+"/tmp/a/b/c")
require.Contains(t, subs.Shared, SharePrefix+"/tmp2/a/b/c")
require.Len(t, subs.Shared[SharePrefix+"/tmp/a/b/c"], 3)
require.Len(t, subs.Shared[SharePrefix+"/tmp2/a/b/c"], 1)
subs.SelectShared()
require.Len(t, subs.SharedSelected, 2)
}
func TestMergeSharedSelected(t *testing.T) {
s := &Subscribers{
SharedSelected: map[string]packets.Subscription{
"cl1": {Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110},
"cl2": {Qos: 1, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 111},
},
Subscriptions: map[string]packets.Subscription{
"cl2": {Qos: 1, Filter: "a/b/c", Identifier: 112},
},
}
s.MergeSharedSelected()
require.Equal(t, 2, len(s.Subscriptions))
require.Contains(t, s.Subscriptions, "cl1")
require.Contains(t, s.Subscriptions, "cl2")
require.EqualValues(t, map[string]int{
SharePrefix + "/tmp2/a/b/c": 111,
"a/b/c": 112,
}, s.Subscriptions["cl2"].Identifiers)
}
func TestSubscribersFind(t *testing.T) {
tt := []struct {
filter string
topic string
matched bool
}{
{filter: "a", topic: "a", matched: true},
{filter: "a/", topic: "a", matched: false},
{filter: "a/", topic: "a/", matched: true},
{filter: "/a", topic: "/a", matched: true},
{filter: "path/to/my/mqtt", topic: "path/to/my/mqtt", matched: true},
{filter: "path/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
{filter: "+/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
{filter: "#", topic: "path/to/my/mqtt", matched: true},
{filter: "+/+/+/+", topic: "path/to/my/mqtt", matched: true},
{filter: "+/+/+/#", topic: "path/to/my/mqtt", matched: true},
{filter: "zen/#", topic: "zen", matched: true}, // as per 4.7.1.2
{filter: "trailing-end/#", topic: "trailing-end/", matched: true},
{filter: "+/prefixed", topic: "/prefixed", matched: true},
{filter: "+/+/#", topic: "path/to/my/mqtt", matched: true},
{filter: "path/to/", topic: "path/to/my/mqtt", matched: false},
{filter: "#/stuff", topic: "path/to/my/mqtt", matched: false},
{filter: "#", topic: "$SYS/info", matched: false},
{filter: "$SYS/#", topic: "$SYS/info", matched: true},
{filter: "+/info", topic: "$SYS/info", matched: false},
}
for _, tx := range tt {
t.Run("filter:'"+tx.filter+"' vs topic:'"+tx.topic+"'", func(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: tx.filter})
subs := index.Subscribers(tx.topic)
require.Equal(t, tx.matched, len(subs.Subscriptions) == 1)
})
}
}
func BenchmarkSubscribers(b *testing.B) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
index.Subscribe("cl1", packets.Subscription{Filter: "a/+/c"})
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/+"})
index.Subscribe("cl2", packets.Subscription{Filter: "a/b/c/d"})
index.Subscribe("cl3", packets.Subscription{Filter: "#"})
for n := 0; n < b.N; n++ {
index.Subscribers("a/b/c")
}
}
func TestMessagesPattern(t *testing.T) {
payload := []byte("hello")
fh := packets.FixedHeader{Type: packets.Publish, Retain: true}
pks := []packets.Packet{
{TopicName: "$SYS/uptime", Payload: payload, FixedHeader: fh},
{TopicName: "$SYS/info", Payload: payload, FixedHeader: fh},
{TopicName: "a/b/c/d", Payload: payload, FixedHeader: fh},
{TopicName: "a/b/c/e", Payload: payload, FixedHeader: fh},
{TopicName: "a/b/d/f", Payload: payload, FixedHeader: fh},
{TopicName: "q/w/e/r/t/y", Payload: payload, FixedHeader: fh},
{TopicName: "q/x/e/r/t/o", Payload: payload, FixedHeader: fh},
{TopicName: "asdf", Payload: payload, FixedHeader: fh},
}
tt := []struct {
filter string
len int
}{
{"a/b/c/d", 1},
{"$SYS/+", 2},
{"$SYS/#", 2},
{"#", len(pks) - 2},
{"a/b/c/+", 2},
{"a/+/c/+", 2},
{"+/+/+/d", 1},
{"q/w/e/#", 1},
{"+/+/+/+", 3},
{"q/#", 2},
{"asdf", 1},
{"", 0},
{"#", 6},
}
index := NewTopicsIndex()
for _, pk := range pks {
index.RetainMessage(pk)
}
for _, tx := range tt {
t.Run("filter:'"+tx.filter, func(t *testing.T) {
messages := index.Messages(tx.filter)
require.Equal(t, tx.len, len(messages))
})
}
}
func BenchmarkMessages(b *testing.B) {
index := NewTopicsIndex()
index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
index.RetainMessage(packets.Packet{TopicName: "a/b/d/e/f"})
index.RetainMessage(packets.Packet{TopicName: "d/e/f/g"})
index.RetainMessage(packets.Packet{TopicName: "$SYS/info"})
index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
for n := 0; n < b.N; n++ {
index.Messages("+/b/c/+")
}
}
func TestNewParticles(t *testing.T) {
cl := newParticles()
require.NotNil(t, cl.internal)
}
func TestParticlesAdd(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
require.Contains(t, p.internal, "a")
}
func TestParticlesGet(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
p.add(&particle{key: "b"})
require.Contains(t, p.internal, "a")
require.Contains(t, p.internal, "b")
particle := p.get("a")
require.NotNil(t, particle)
require.Equal(t, "a", particle.key)
}
func TestParticlesGetAll(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
p.add(&particle{key: "b"})
p.add(&particle{key: "c"})
require.Contains(t, p.internal, "a")
require.Contains(t, p.internal, "b")
require.Contains(t, p.internal, "c")
particles := p.getAll()
require.Len(t, particles, 3)
}
func TestParticlesLen(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
p.add(&particle{key: "b"})
require.Contains(t, p.internal, "a")
require.Contains(t, p.internal, "b")
require.Equal(t, 2, p.len())
}
func TestParticlesDelete(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
require.Contains(t, p.internal, "a")
p.delete("a")
particle := p.get("a")
require.Nil(t, particle)
}
func TestIsValid(t *testing.T) {
require.True(t, IsValidFilter("a/b/c", false))
require.True(t, IsValidFilter("a/b//c", false))
require.True(t, IsValidFilter("$SYS", false))
require.True(t, IsValidFilter("$SYS/info", false))
require.True(t, IsValidFilter("$sys/info", false))
require.True(t, IsValidFilter("abc/#", false))
require.False(t, IsValidFilter("", false))
require.False(t, IsValidFilter(SharePrefix, false))
require.False(t, IsValidFilter(SharePrefix+"/", false))
require.False(t, IsValidFilter(SharePrefix+"/b+/", false))
require.False(t, IsValidFilter(SharePrefix+"/+", false))
require.False(t, IsValidFilter(SharePrefix+"/#", false))
require.False(t, IsValidFilter(SharePrefix+"/#/", false))
require.False(t, IsValidFilter("a/#/c", false))
}
func TestIsValidForPublish(t *testing.T) {
require.True(t, IsValidFilter("", true))
require.True(t, IsValidFilter("a/b/c", true))
require.False(t, IsValidFilter("a/b/+/d", true))
require.False(t, IsValidFilter("a/b/#", true))
require.False(t, IsValidFilter("$SYS/info", true))
}
func TestIsSharedFilter(t *testing.T) {
require.True(t, IsSharedFilter(SharePrefix+"/tmp/a/b/c"))
require.False(t, IsSharedFilter("a/b/c"))
}
func TestNewInboundAliases(t *testing.T) {
a := NewInboundTopicAliases(5)
require.NotNil(t, a)
require.NotNil(t, a.internal)
require.Equal(t, uint16(5), a.maximum)
}
func TestInboundAliasesSet(t *testing.T) {
topic := "test"
id := uint16(1)
a := NewInboundTopicAliases(5)
require.Equal(t, topic, a.Set(id, topic))
require.Contains(t, a.internal, id)
require.Equal(t, a.internal[id], topic)
require.Equal(t, topic, a.Set(id, ""))
}
func TestInboundAliasesSetMaxZero(t *testing.T) {
topic := "test"
id := uint16(1)
a := NewInboundTopicAliases(0)
require.Equal(t, topic, a.Set(id, topic))
require.NotContains(t, a.internal, id)
}
func TestNewOutboundAliases(t *testing.T) {
a := NewOutboundTopicAliases(5)
require.NotNil(t, a)
require.NotNil(t, a.internal)
require.Equal(t, uint16(5), a.maximum)
require.Equal(t, uint32(0), a.cursor)
}
func TestOutboundAliasesSet(t *testing.T) {
a := NewOutboundTopicAliases(3)
n, ok := a.Set("t1")
require.False(t, ok)
require.Equal(t, uint16(1), n)
n, ok = a.Set("t2")
require.False(t, ok)
require.Equal(t, uint16(2), n)
n, ok = a.Set("t3")
require.False(t, ok)
require.Equal(t, uint16(3), n)
n, ok = a.Set("t4")
require.False(t, ok)
require.Equal(t, uint16(0), n)
n, ok = a.Set("t2")
require.True(t, ok)
require.Equal(t, uint16(2), n)
}
func TestOutboundAliasesSetMaxZero(t *testing.T) {
topic := "test"
a := NewOutboundTopicAliases(0)
n, ok := a.Set(topic)
require.False(t, ok)
require.Equal(t, uint16(0), n)
}
func TestNewTopicAliases(t *testing.T) {
a := NewTopicAliases(5)
require.NotNil(t, a.Inbound)
require.Equal(t, uint16(5), a.Inbound.maximum)
require.NotNil(t, a.Outbound)
require.Equal(t, uint16(5), a.Outbound.maximum)
}

View File

@ -1 +0,0 @@
language: go

View File

@ -1,35 +0,0 @@
bbloom.go
// The MIT License (MIT)
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
siphash.go
// https://github.com/dchest/siphash
//
// Written in 2012 by Dmitry Chestnykh.
//
// To the extent possible under law, the author have dedicated all copyright
// and related and neighboring rights to this software to the public domain
// worldwide. This software is distributed without any warranty.
// http://creativecommons.org/publicdomain/zero/1.0/
//
// Package siphash implements SipHash-2-4, a fast short-input PRF
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.

View File

@ -1,131 +0,0 @@
## bbloom: a bitset Bloom filter for go/golang
===
[![Build Status](https://travis-ci.org/AndreasBriese/bbloom.png?branch=master)](http://travis-ci.org/AndreasBriese/bbloom)
package implements a fast bloom filter with real 'bitset' and JSONMarshal/JSONUnmarshal to store/reload the Bloom filter.
NOTE: the package uses unsafe.Pointer to set and read the bits from the bitset. If you're uncomfortable with using the unsafe package, please consider using my bloom filter package at github.com/AndreasBriese/bloom
===
changelog 11/2015: new thread safe methods AddTS(), HasTS(), AddIfNotHasTS() following a suggestion from Srdjan Marinovic (github @a-little-srdjan), who used this to code a bloomfilter cache.
This bloom filter was developed to strengthen a website-log database and was tested and optimized for this log-entry mask: "2014/%02i/%02i %02i:%02i:%02i /info.html".
Nonetheless bbloom should work with any other form of entries.
~~Hash function is a modified Berkeley DB sdbm hash (to optimize for smaller strings). sdbm http://www.cse.yorku.ca/~oz/hash.html~~
Found sipHash (SipHash-2-4, a fast short-input PRF created by Jean-Philippe Aumasson and Daniel J. Bernstein.) to be about as fast. sipHash had been ported by Dimtry Chestnyk to Go (github.com/dchest/siphash )
Minimum hashset size is: 512 ([4]uint64; will be set automatically).
###install
```sh
go get github.com/AndreasBriese/bbloom
```
###test
+ change to folder ../bbloom
+ create wordlist in file "words.txt" (you might use `python permut.py`)
+ run 'go test -bench=.' within the folder
```go
go test -bench=.
```
~~If you've installed the GOCONVEY TDD-framework http://goconvey.co/ you can run the tests automatically.~~
using go's testing framework now (have in mind that the op timing is related to 65536 operations of Add, Has, AddIfNotHas respectively)
### usage
after installation add
```go
import (
...
"github.com/AndreasBriese/bbloom"
...
)
```
at your header. In the program use
```go
// create a bloom filter for 65536 items and 1 % wrong-positive ratio
bf := bbloom.New(float64(1<<16), float64(0.01))
// or
// create a bloom filter with 650000 for 65536 items and 7 locs per hash explicitly
// bf = bbloom.New(float64(650000), float64(7))
// or
bf = bbloom.New(650000.0, 7.0)
// add one item
bf.Add([]byte("butter"))
// Number of elements added is exposed now
// Note: ElemNum will not be included in JSON export (for compatability to older version)
nOfElementsInFilter := bf.ElemNum
// check if item is in the filter
isIn := bf.Has([]byte("butter")) // should be true
isNotIn := bf.Has([]byte("Butter")) // should be false
// 'add only if item is new' to the bloomfilter
added := bf.AddIfNotHas([]byte("butter")) // should be false because 'butter' is already in the set
added = bf.AddIfNotHas([]byte("buTTer")) // should be true because 'buTTer' is new
// thread safe versions for concurrent use: AddTS, HasTS, AddIfNotHasTS
// add one item
bf.AddTS([]byte("peanutbutter"))
// check if item is in the filter
isIn = bf.HasTS([]byte("peanutbutter")) // should be true
isNotIn = bf.HasTS([]byte("peanutButter")) // should be false
// 'add only if item is new' to the bloomfilter
added = bf.AddIfNotHasTS([]byte("butter")) // should be false because 'peanutbutter' is already in the set
added = bf.AddIfNotHasTS([]byte("peanutbuTTer")) // should be true because 'penutbuTTer' is new
// convert to JSON ([]byte)
Json := bf.JSONMarshal()
// bloomfilters Mutex is exposed for external un-/locking
// i.e. mutex lock while doing JSON conversion
bf.Mtx.Lock()
Json = bf.JSONMarshal()
bf.Mtx.Unlock()
// restore a bloom filter from storage
bfNew := bbloom.JSONUnmarshal(Json)
isInNew := bfNew.Has([]byte("butter")) // should be true
isNotInNew := bfNew.Has([]byte("Butter")) // should be false
```
to work with the bloom filter.
### why 'fast'?
It's about 3 times faster than William Fitzgeralds bitset bloom filter https://github.com/willf/bloom . And it is about so fast as my []bool set variant for Boom filters (see https://github.com/AndreasBriese/bloom ) but having a 8times smaller memory footprint:
Bloom filter (filter size 524288, 7 hashlocs)
github.com/AndreasBriese/bbloom 'Add' 65536 items (10 repetitions): 6595800 ns (100 ns/op)
github.com/AndreasBriese/bbloom 'Has' 65536 items (10 repetitions): 5986600 ns (91 ns/op)
github.com/AndreasBriese/bloom 'Add' 65536 items (10 repetitions): 6304684 ns (96 ns/op)
github.com/AndreasBriese/bloom 'Has' 65536 items (10 repetitions): 6568663 ns (100 ns/op)
github.com/willf/bloom 'Add' 65536 items (10 repetitions): 24367224 ns (371 ns/op)
github.com/willf/bloom 'Test' 65536 items (10 repetitions): 21881142 ns (333 ns/op)
github.com/dataence/bloom/standard 'Add' 65536 items (10 repetitions): 23041644 ns (351 ns/op)
github.com/dataence/bloom/standard 'Check' 65536 items (10 repetitions): 19153133 ns (292 ns/op)
github.com/cabello/bloom 'Add' 65536 items (10 repetitions): 131921507 ns (2012 ns/op)
github.com/cabello/bloom 'Contains' 65536 items (10 repetitions): 131108962 ns (2000 ns/op)
(on MBPro15 OSX10.8.5 i7 4Core 2.4Ghz)
With 32bit bloom filters (bloom32) using modified sdbm, bloom32 does hashing with only 2 bit shifts, one xor and one substraction per byte. smdb is about as fast as fnv64a but gives less collisions with the dataset (see mask above). bloom.New(float64(10 * 1<<16),float64(7)) populated with 1<<16 random items from the dataset (see above) and tested against the rest results in less than 0.05% collisions.

View File

@ -1,284 +0,0 @@
// The MIT License (MIT)
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// 2019/08/25 code revision to reduce unsafe use
// Parts are adopted from the fork at ipfs/bbloom after performance rev by
// Steve Allen (https://github.com/Stebalien)
// (see https://github.com/ipfs/bbloom/blob/master/bbloom.go)
// -> func Has
// -> func set
// -> func add
package bbloom
import (
"bytes"
"encoding/json"
"log"
"math"
"sync"
"unsafe"
)
// helper
// not needed anymore by Set
// var mask = []uint8{1, 2, 4, 8, 16, 32, 64, 128}
func getSize(ui64 uint64) (size uint64, exponent uint64) {
if ui64 < uint64(512) {
ui64 = uint64(512)
}
size = uint64(1)
for size < ui64 {
size <<= 1
exponent++
}
return size, exponent
}
func calcSizeByWrongPositives(numEntries, wrongs float64) (uint64, uint64) {
size := -1 * numEntries * math.Log(wrongs) / math.Pow(float64(0.69314718056), 2)
locs := math.Ceil(float64(0.69314718056) * size / numEntries)
return uint64(size), uint64(locs)
}
// New
// returns a new bloomfilter
func New(params ...float64) (bloomfilter Bloom) {
var entries, locs uint64
if len(params) == 2 {
if params[1] < 1 {
entries, locs = calcSizeByWrongPositives(params[0], params[1])
} else {
entries, locs = uint64(params[0]), uint64(params[1])
}
} else {
log.Fatal("usage: New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(3)) or New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(0.03))")
}
size, exponent := getSize(uint64(entries))
bloomfilter = Bloom{
Mtx: &sync.Mutex{},
sizeExp: exponent,
size: size - 1,
setLocs: locs,
shift: 64 - exponent,
}
bloomfilter.Size(size)
return bloomfilter
}
// NewWithBoolset
// takes a []byte slice and number of locs per entry
// returns the bloomfilter with a bitset populated according to the input []byte
func NewWithBoolset(bs *[]byte, locs uint64) (bloomfilter Bloom) {
bloomfilter = New(float64(len(*bs)<<3), float64(locs))
for i, b := range *bs {
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bloomfilter.bitset[0])) + uintptr(i))) = b
}
return bloomfilter
}
// bloomJSONImExport
// Im/Export structure used by JSONMarshal / JSONUnmarshal
type bloomJSONImExport struct {
FilterSet []byte
SetLocs uint64
}
// JSONUnmarshal
// takes JSON-Object (type bloomJSONImExport) as []bytes
// returns Bloom object
func JSONUnmarshal(dbData []byte) Bloom {
bloomImEx := bloomJSONImExport{}
json.Unmarshal(dbData, &bloomImEx)
buf := bytes.NewBuffer(bloomImEx.FilterSet)
bs := buf.Bytes()
bf := NewWithBoolset(&bs, bloomImEx.SetLocs)
return bf
}
//
// Bloom filter
type Bloom struct {
Mtx *sync.Mutex
ElemNum uint64
bitset []uint64
sizeExp uint64
size uint64
setLocs uint64
shift uint64
}
// <--- http://www.cse.yorku.ca/~oz/hash.html
// modified Berkeley DB Hash (32bit)
// hash is casted to l, h = 16bit fragments
// func (bl Bloom) absdbm(b *[]byte) (l, h uint64) {
// hash := uint64(len(*b))
// for _, c := range *b {
// hash = uint64(c) + (hash << 6) + (hash << bl.sizeExp) - hash
// }
// h = hash >> bl.shift
// l = hash << bl.shift >> bl.shift
// return l, h
// }
// Update: found sipHash of Jean-Philippe Aumasson & Daniel J. Bernstein to be even faster than absdbm()
// https://131002.net/siphash/
// siphash was implemented for Go by Dmitry Chestnykh https://github.com/dchest/siphash
// Add
// set the bit(s) for entry; Adds an entry to the Bloom filter
func (bl *Bloom) Add(entry []byte) {
l, h := bl.sipHash(entry)
for i := uint64(0); i < bl.setLocs; i++ {
bl.set((h + i*l) & bl.size)
bl.ElemNum++
}
}
// AddTS
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
func (bl *Bloom) AddTS(entry []byte) {
bl.Mtx.Lock()
defer bl.Mtx.Unlock()
bl.Add(entry)
}
// Has
// check if bit(s) for entry is/are set
// returns true if the entry was added to the Bloom Filter
func (bl Bloom) Has(entry []byte) bool {
l, h := bl.sipHash(entry)
res := true
for i := uint64(0); i < bl.setLocs; i++ {
res = res && bl.isSet((h+i*l)&bl.size)
// https://github.com/ipfs/bbloom/commit/84e8303a9bfb37b2658b85982921d15bbb0fecff
// // Branching here (early escape) is not worth it
// // This is my conclusion from benchmarks
// // (prevents loop unrolling)
// switch bl.IsSet((h + i*l) & bl.size) {
// case false:
// return false
// }
}
return res
}
// HasTS
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
func (bl *Bloom) HasTS(entry []byte) bool {
bl.Mtx.Lock()
defer bl.Mtx.Unlock()
return bl.Has(entry)
}
// AddIfNotHas
// Only Add entry if it's not present in the bloomfilter
// returns true if entry was added
// returns false if entry was allready registered in the bloomfilter
func (bl Bloom) AddIfNotHas(entry []byte) (added bool) {
if bl.Has(entry) {
return added
}
bl.Add(entry)
return true
}
// AddIfNotHasTS
// Tread safe: Only Add entry if it's not present in the bloomfilter
// returns true if entry was added
// returns false if entry was allready registered in the bloomfilter
func (bl *Bloom) AddIfNotHasTS(entry []byte) (added bool) {
bl.Mtx.Lock()
defer bl.Mtx.Unlock()
return bl.AddIfNotHas(entry)
}
// Size
// make Bloom filter with as bitset of size sz
func (bl *Bloom) Size(sz uint64) {
bl.bitset = make([]uint64, sz>>6)
}
// Clear
// resets the Bloom filter
func (bl *Bloom) Clear() {
bs := bl.bitset
for i := range bs {
bs[i] = 0
}
}
// Set
// set the bit[idx] of bitsit
func (bl *Bloom) set(idx uint64) {
// ommit unsafe
// *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3))) |= mask[idx%8]
bl.bitset[idx>>6] |= 1 << (idx % 64)
}
// IsSet
// check if bit[idx] of bitset is set
// returns true/false
func (bl *Bloom) isSet(idx uint64) bool {
// ommit unsafe
// return (((*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)))) >> (idx % 8)) & 1) == 1
return bl.bitset[idx>>6]&(1<<(idx%64)) != 0
}
// JSONMarshal
// returns JSON-object (type bloomJSONImExport) as []byte
func (bl Bloom) JSONMarshal() []byte {
bloomImEx := bloomJSONImExport{}
bloomImEx.SetLocs = uint64(bl.setLocs)
bloomImEx.FilterSet = make([]byte, len(bl.bitset)<<3)
for i := range bloomImEx.FilterSet {
bloomImEx.FilterSet[i] = *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[0])) + uintptr(i)))
}
data, err := json.Marshal(bloomImEx)
if err != nil {
log.Fatal("json.Marshal failed: ", err)
}
return data
}
// // alternative hashFn
// func (bl Bloom) fnv64a(b *[]byte) (l, h uint64) {
// h64 := fnv.New64a()
// h64.Write(*b)
// hash := h64.Sum64()
// h = hash >> 32
// l = hash << 32 >> 32
// return l, h
// }
//
// // <-- http://partow.net/programming/hashfunctions/index.html
// // citation: An algorithm proposed by Donald E. Knuth in The Art Of Computer Programming Volume 3,
// // under the topic of sorting and search chapter 6.4.
// // modified to fit with boolset-length
// func (bl Bloom) DEKHash(b *[]byte) (l, h uint64) {
// hash := uint64(len(*b))
// for _, c := range *b {
// hash = ((hash << 5) ^ (hash >> bl.shift)) ^ uint64(c)
// }
// h = hash >> bl.shift
// l = hash << bl.sizeExp >> bl.sizeExp
// return l, h
// }

View File

@ -1,225 +0,0 @@
// Written in 2012 by Dmitry Chestnykh.
//
// To the extent possible under law, the author have dedicated all copyright
// and related and neighboring rights to this software to the public domain
// worldwide. This software is distributed without any warranty.
// http://creativecommons.org/publicdomain/zero/1.0/
//
// Package siphash implements SipHash-2-4, a fast short-input PRF
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
package bbloom
// Hash returns the 64-bit SipHash-2-4 of the given byte slice with two 64-bit
// parts of 128-bit key: k0 and k1.
func (bl Bloom) sipHash(p []byte) (l, h uint64) {
// Initialization.
v0 := uint64(8317987320269560794) // k0 ^ 0x736f6d6570736575
v1 := uint64(7237128889637516672) // k1 ^ 0x646f72616e646f6d
v2 := uint64(7816392314733513934) // k0 ^ 0x6c7967656e657261
v3 := uint64(8387220255325274014) // k1 ^ 0x7465646279746573
t := uint64(len(p)) << 56
// Compression.
for len(p) >= 8 {
m := uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 |
uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
v3 ^= m
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
v0 ^= m
p = p[8:]
}
// Compress last block.
switch len(p) {
case 7:
t |= uint64(p[6]) << 48
fallthrough
case 6:
t |= uint64(p[5]) << 40
fallthrough
case 5:
t |= uint64(p[4]) << 32
fallthrough
case 4:
t |= uint64(p[3]) << 24
fallthrough
case 3:
t |= uint64(p[2]) << 16
fallthrough
case 2:
t |= uint64(p[1]) << 8
fallthrough
case 1:
t |= uint64(p[0])
}
v3 ^= t
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
v0 ^= t
// Finalization.
v2 ^= 0xff
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 3.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 4.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// return v0 ^ v1 ^ v2 ^ v3
hash := v0 ^ v1 ^ v2 ^ v3
h = hash >> bl.shift
l = hash << bl.shift >> bl.shift
return l, h
}

View File

@ -1,24 +0,0 @@
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <http://unlicense.org/>

View File

@ -1,7 +0,0 @@
# gopher-json [![GoDoc](https://godoc.org/layeh.com/gopher-json?status.svg)](https://godoc.org/layeh.com/gopher-json)
Package json is a simple JSON encoder/decoder for [gopher-lua](https://github.com/yuin/gopher-lua).
## License
Public domain

View File

@ -1,33 +0,0 @@
// Package json is a simple JSON encoder/decoder for gopher-lua.
//
// Documentation
//
// The following functions are exposed by the library:
// decode(string): Decodes a JSON string. Returns nil and an error string if
// the string could not be decoded.
// encode(value): Encodes a value into a JSON string. Returns nil and an error
// string if the value could not be encoded.
//
// The following types are supported:
//
// Lua | JSON
// ---------+-----
// nil | null
// number | number
// string | string
// table | object: when table is non-empty and has only string keys
// | array: when table is empty, or has only sequential numeric keys
// | starting from 1
//
// Attempting to encode any other Lua type will result in an error.
//
// Example
//
// Below is an example usage of the library:
// import (
// luajson "layeh.com/gopher-json"
// )
//
// L := lua.NewState()
// luajson.Preload(s)
package json

View File

@ -1,189 +0,0 @@
package json
import (
"encoding/json"
"errors"
"github.com/yuin/gopher-lua"
)
// Preload adds json to the given Lua state's package.preload table. After it
// has been preloaded, it can be loaded using require:
//
// local json = require("json")
func Preload(L *lua.LState) {
L.PreloadModule("json", Loader)
}
// Loader is the module loader function.
func Loader(L *lua.LState) int {
t := L.NewTable()
L.SetFuncs(t, api)
L.Push(t)
return 1
}
var api = map[string]lua.LGFunction{
"decode": apiDecode,
"encode": apiEncode,
}
func apiDecode(L *lua.LState) int {
if L.GetTop() != 1 {
L.Error(lua.LString("bad argument #1 to decode"), 1)
return 0
}
str := L.CheckString(1)
value, err := Decode(L, []byte(str))
if err != nil {
L.Push(lua.LNil)
L.Push(lua.LString(err.Error()))
return 2
}
L.Push(value)
return 1
}
func apiEncode(L *lua.LState) int {
if L.GetTop() != 1 {
L.Error(lua.LString("bad argument #1 to encode"), 1)
return 0
}
value := L.CheckAny(1)
data, err := Encode(value)
if err != nil {
L.Push(lua.LNil)
L.Push(lua.LString(err.Error()))
return 2
}
L.Push(lua.LString(string(data)))
return 1
}
var (
errNested = errors.New("cannot encode recursively nested tables to JSON")
errSparseArray = errors.New("cannot encode sparse array")
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
)
type invalidTypeError lua.LValueType
func (i invalidTypeError) Error() string {
return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
}
// Encode returns the JSON encoding of value.
func Encode(value lua.LValue) ([]byte, error) {
return json.Marshal(jsonValue{
LValue: value,
visited: make(map[*lua.LTable]bool),
})
}
type jsonValue struct {
lua.LValue
visited map[*lua.LTable]bool
}
func (j jsonValue) MarshalJSON() (data []byte, err error) {
switch converted := j.LValue.(type) {
case lua.LBool:
data, err = json.Marshal(bool(converted))
case lua.LNumber:
data, err = json.Marshal(float64(converted))
case *lua.LNilType:
data = []byte(`null`)
case lua.LString:
data, err = json.Marshal(string(converted))
case *lua.LTable:
if j.visited[converted] {
return nil, errNested
}
j.visited[converted] = true
key, value := converted.Next(lua.LNil)
switch key.Type() {
case lua.LTNil: // empty table
data = []byte(`[]`)
case lua.LTNumber:
arr := make([]jsonValue, 0, converted.Len())
expectedKey := lua.LNumber(1)
for key != lua.LNil {
if key.Type() != lua.LTNumber {
err = errInvalidKeys
return
}
if expectedKey != key {
err = errSparseArray
return
}
arr = append(arr, jsonValue{value, j.visited})
expectedKey++
key, value = converted.Next(key)
}
data, err = json.Marshal(arr)
case lua.LTString:
obj := make(map[string]jsonValue)
for key != lua.LNil {
if key.Type() != lua.LTString {
err = errInvalidKeys
return
}
obj[key.String()] = jsonValue{value, j.visited}
key, value = converted.Next(key)
}
data, err = json.Marshal(obj)
default:
err = errInvalidKeys
}
default:
err = invalidTypeError(j.LValue.Type())
}
return
}
// Decode converts the JSON encoded data to Lua values.
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
var value interface{}
err := json.Unmarshal(data, &value)
if err != nil {
return nil, err
}
return DecodeValue(L, value), nil
}
// DecodeValue converts the value to a Lua value.
//
// This function only converts values that the encoding/json package decodes to.
// All other values will return lua.LNil.
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
switch converted := value.(type) {
case bool:
return lua.LBool(converted)
case float64:
return lua.LNumber(converted)
case string:
return lua.LString(converted)
case json.Number:
return lua.LString(converted)
case []interface{}:
arr := L.CreateTable(len(converted), 0)
for _, item := range converted {
arr.Append(DecodeValue(L, item))
}
return arr
case map[string]interface{}:
tbl := L.CreateTable(0, len(converted))
for key, item := range converted {
tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
}
return tbl
case nil:
return lua.LNil
}
return lua.LNil
}

View File

@ -1,6 +0,0 @@
/integration/redis_src/
/integration/dump.rdb
*.swp
/integration/nodes.conf
.idea/
miniredis.iml

View File

@ -1,225 +0,0 @@
## Changelog
### v2.23.0
- basic INFO support (thanks @kirill-a-belov)
- support COUNT in SSCAN (thanks @Abdi-dd)
- test and support Go 1.19
- support LPOS (thanks @ianstarz)
- support XPENDING, XGROUP {CREATECONSUMER,DESTROY,DELCONSUMER}, XINFO {CONSUMERS,GROUPS}, XCLAIM (thanks @sandyharvie)
### v2.22.0
- set miniredis.DumpMaxLineLen to get more Dump() info (thanks @afjoseph)
- fix invalid resposne of COMMAND (thanks @zsh1995)
- fix possibility to generate duplicate IDs in XADD (thanks @readams)
- adds support for XAUTOCLAIM min-idle parameter (thanks @readams)
### v2.21.0
- support for GETEX (thanks @dntj)
- support for GT and LT in ZADD (thanks @lsgndln)
- support for XAUTOCLAIM (thanks @randall-fulton)
### v2.20.0
- back to support Go >= 1.14 (thanks @ajatprabha and @marcind)
### v2.19.0
- support for TYPE in SCAN (thanks @0xDiddi)
- update BITPOS (thanks @dirkm)
- fix a lua redis.call() return value (thanks @mpetronic)
- update ZRANGE (thanks @valdemarpereira)
### v2.18.0
- support for ZUNION (thanks @propan)
- support for COPY (thanks @matiasinsaurralde and @rockitbaby)
- support for LMOVE (thanks @btwear)
### v2.17.0
- added miniredis.RunT(t)
### v2.16.1
- fix ZINTERSTORE with wets (thanks @lingjl2010 and @okhowang)
- fix exclusive ranges in XRANGE (thanks @joseotoro)
### v2.16.0
- simplify some code (thanks @zonque)
- support for EXAT/PXAT in SET
- support for XTRIM (thanks @joseotoro)
- support for ZRANDMEMBER
- support for redis.log() in lua (thanks @dirkm)
### v2.15.2
- Fix race condition in blocking code (thanks @zonque and @robx)
- XREAD accepts '$' as ID (thanks @bradengroom)
### v2.15.1
- EVAL should cache the script (thanks @guoshimin)
### v2.15.0
- target redis 6.2 and added new args to various commands
- support for all hyperlog commands (thanks @ilbaktin)
- support for GETDEL (thanks @wszaranski)
### v2.14.5
- added XPENDING
- support for BLOCK option in XREAD and XREADGROUP
### v2.14.4
- fix BITPOS error (thanks @xiaoyuzdy)
- small fixes for XREAD, XACK, and XDEL. Mostly error cases.
- fix empty EXEC return type (thanks @ashanbrown)
- fix XDEL (thanks @svakili and @yvesf)
- fix FLUSHALL for streams (thanks @svakili)
### v2.14.3
- fix problem where Lua code didn't set the selected DB
- update to redis 6.0.10 (thanks @lazappa)
### v2.14.2
- update LUA dependency
- deal with (p)unsubscribe when there are no channels
### v2.14.1
- mod tidy
### v2.14.0
- support for HELLO and the RESP3 protocol
- KEEPTTL in SET (thanks @johnpena)
### v2.13.3
- support Go 1.14 and 1.15
- update the `Check...()` methods
- support for XREAD (thanks @pieterlexis)
### v2.13.2
- Use SAN instead of CN in self signed cert for testing (thanks @johejo)
- Travis CI now tests against the most recent two versions of Go (thanks @johejo)
- changed unit and integration tests to compare raw payloads, not parsed payloads
- remove "redigo" dependency
### v2.13.1
- added HSTRLEN
- minimal support for ACL users in AUTH
### v2.13.0
- added RunTLS(...)
- added SetError(...)
### v2.12.0
- redis 6
- Lua json update (thanks @gsmith85)
- CLUSTER commands (thanks @kratisto)
- fix TOUCH
- fix a shutdown race condition
### v2.11.4
- ZUNIONSTORE now supports standard set types (thanks @wshirey)
### v2.11.3
- support for TOUCH (thanks @cleroux)
- support for cluster and stream commands (thanks @kak-tus)
### v2.11.2
- make sure Lua code is executed concurrently
- add command GEORADIUSBYMEMBER (thanks @kyeett)
### v2.11.1
- globals protection for Lua code (thanks @vk-outreach)
- HSET update (thanks @carlgreen)
- fix BLPOP block on shutdown (thanks @Asalle)
### v2.11.0
- added XRANGE/XREVRANGE, XADD, and XLEN (thanks @skateinmars)
- added GEODIST
- improved precision for geohashes, closer to what real redis does
- use 128bit floats internally for INCRBYFLOAT and related (thanks @timnd)
### v2.10.1
- added m.Server()
### v2.10.0
- added UNLINK
- fix DEL zero-argument case
- cleanup some direct access commands
- added GEOADD, GEOPOS, GEORADIUS, and GEORADIUS_RO
### v2.9.1
- fix issue with ZRANGEBYLEX
- fix issue with BRPOPLPUSH and direct access
### v2.9.0
- proper versioned import of github.com/gomodule/redigo (thanks @yfei1)
- fix messages generated by PSUBSCRIBE
- optional internal seed (thanks @zikaeroh)
### v2.8.0
Proper `v2` in go.mod.
### older
See https://github.com/alicebob/miniredis/releases for the full changelog

View File

@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2014 Harmen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,12 +0,0 @@
.PHONY: all test testrace int
all: test
test:
go test ./...
testrace:
go test -race ./...
int:
${MAKE} -C integration all

View File

@ -1,333 +0,0 @@
# Miniredis
Pure Go Redis test server, used in Go unittests.
##
Sometimes you want to test code which uses Redis, without making it a full-blown
integration test.
Miniredis implements (parts of) the Redis server, to be used in unittests. It
enables a simple, cheap, in-memory, Redis replacement, with a real TCP interface. Think of it as the Redis version of `net/http/httptest`.
It saves you from using mock code, and since the redis server lives in the
test process you can query for values directly, without going through the server
stack.
There are no dependencies on external binaries, so you can easily integrate it in automated build processes.
Be sure to import v2:
```
import "github.com/alicebob/miniredis/v2"
```
## Commands
Implemented commands:
- Connection (complete)
- AUTH -- see RequireAuth()
- ECHO
- HELLO -- see RequireUserAuth()
- PING
- SELECT
- SWAPDB
- QUIT
- Key
- COPY
- DEL
- EXISTS
- EXPIRE
- EXPIREAT
- KEYS
- MOVE
- PERSIST
- PEXPIRE
- PEXPIREAT
- PTTL
- RENAME
- RENAMENX
- RANDOMKEY -- see m.Seed(...)
- SCAN
- TOUCH
- TTL
- TYPE
- UNLINK
- Transactions (complete)
- DISCARD
- EXEC
- MULTI
- UNWATCH
- WATCH
- Server
- DBSIZE
- FLUSHALL
- FLUSHDB
- TIME -- returns time.Now() or value set by SetTime()
- COMMAND -- partly
- INFO -- partly, returns only "clients" section with one field "connected_clients"
- String keys (complete)
- APPEND
- BITCOUNT
- BITOP
- BITPOS
- DECR
- DECRBY
- GET
- GETBIT
- GETRANGE
- GETSET
- GETDEL
- GETEX
- INCR
- INCRBY
- INCRBYFLOAT
- MGET
- MSET
- MSETNX
- PSETEX
- SET
- SETBIT
- SETEX
- SETNX
- SETRANGE
- STRLEN
- Hash keys (complete)
- HDEL
- HEXISTS
- HGET
- HGETALL
- HINCRBY
- HINCRBYFLOAT
- HKEYS
- HLEN
- HMGET
- HMSET
- HSET
- HSETNX
- HSTRLEN
- HVALS
- HSCAN
- List keys (complete)
- BLPOP
- BRPOP
- BRPOPLPUSH
- LINDEX
- LINSERT
- LLEN
- LPOP
- LPUSH
- LPUSHX
- LRANGE
- LREM
- LSET
- LTRIM
- RPOP
- RPOPLPUSH
- RPUSH
- RPUSHX
- LMOVE
- Pub/Sub (complete)
- PSUBSCRIBE
- PUBLISH
- PUBSUB
- PUNSUBSCRIBE
- SUBSCRIBE
- UNSUBSCRIBE
- Set keys (complete)
- SADD
- SCARD
- SDIFF
- SDIFFSTORE
- SINTER
- SINTERSTORE
- SISMEMBER
- SMEMBERS
- SMOVE
- SPOP -- see m.Seed(...)
- SRANDMEMBER -- see m.Seed(...)
- SREM
- SUNION
- SUNIONSTORE
- SSCAN
- Sorted Set keys (complete)
- ZADD
- ZCARD
- ZCOUNT
- ZINCRBY
- ZINTERSTORE
- ZLEXCOUNT
- ZPOPMIN
- ZPOPMAX
- ZRANDMEMBER
- ZRANGE
- ZRANGEBYLEX
- ZRANGEBYSCORE
- ZRANK
- ZREM
- ZREMRANGEBYLEX
- ZREMRANGEBYRANK
- ZREMRANGEBYSCORE
- ZREVRANGE
- ZREVRANGEBYLEX
- ZREVRANGEBYSCORE
- ZREVRANK
- ZSCORE
- ZUNION
- ZUNIONSTORE
- ZSCAN
- Stream keys
- XACK
- XADD
- XAUTOCLAIM
- XCLAIM
- XDEL
- XGROUP CREATE
- XGROUP CREATECONSUMER
- XGROUP DESTROY
- XGROUP DELCONSUMER
- XINFO STREAM -- partly
- XINFO GROUPS
- XINFO CONSUMERS -- partly
- XLEN
- XRANGE
- XREAD
- XREADGROUP
- XREVRANGE
- XPENDING
- XTRIM
- Scripting
- EVAL
- EVALSHA
- SCRIPT LOAD
- SCRIPT EXISTS
- SCRIPT FLUSH
- GEO
- GEOADD
- GEODIST
- ~~GEOHASH~~
- GEOPOS
- GEORADIUS
- GEORADIUS_RO
- GEORADIUSBYMEMBER
- GEORADIUSBYMEMBER_RO
- Cluster
- CLUSTER SLOTS
- CLUSTER KEYSLOT
- CLUSTER NODES
- HyperLogLog (complete)
- PFADD
- PFCOUNT
- PFMERGE
## TTLs, key expiration, and time
Since miniredis is intended to be used in unittests TTLs don't decrease
automatically. You can use `TTL()` to get the TTL (as a time.Duration) of a
key. It will return 0 when no TTL is set.
`m.FastForward(d)` can be used to decrement all TTLs. All TTLs which become <=
0 will be removed.
EXPIREAT and PEXPIREAT values will be
converted to a duration. For that you can either set m.SetTime(t) to use that
time as the base for the (P)EXPIREAT conversion, or don't call SetTime(), in
which case time.Now() will be used.
SetTime() also sets the value returned by TIME, which defaults to time.Now().
It is not updated by FastForward, only by SetTime.
## Randomness and Seed()
Miniredis will use `math/rand`'s global RNG for randomness unless a seed is
provided by calling `m.Seed(...)`. If a seed is provided, then miniredis will
use its own RNG based on that seed.
Commands which use randomness are: RANDOMKEY, SPOP, and SRANDMEMBER.
## Example
``` Go
import (
...
"github.com/alicebob/miniredis/v2"
...
)
func TestSomething(t *testing.T) {
s := miniredis.RunT(t)
// Optionally set some keys your code expects:
s.Set("foo", "bar")
s.HSet("some", "other", "key")
// Run your code and see if it behaves.
// An example using the redigo library from "github.com/gomodule/redigo/redis":
c, err := redis.Dial("tcp", s.Addr())
_, err = c.Do("SET", "foo", "bar")
// Optionally check values in redis...
if got, err := s.Get("foo"); err != nil || got != "bar" {
t.Error("'foo' has the wrong value")
}
// ... or use a helper for that:
s.CheckGet(t, "foo", "bar")
// TTL and expiration:
s.Set("foo", "bar")
s.SetTTL("foo", 10*time.Second)
s.FastForward(11 * time.Second)
if s.Exists("foo") {
t.Fatal("'foo' should not have existed anymore")
}
}
```
## Not supported
Commands which will probably not be implemented:
- CLUSTER (all)
- ~~CLUSTER *~~
- ~~READONLY~~
- ~~READWRITE~~
- Key
- ~~DUMP~~
- ~~MIGRATE~~
- ~~OBJECT~~
- ~~RESTORE~~
- ~~WAIT~~
- Scripting
- ~~SCRIPT DEBUG~~
- ~~SCRIPT KILL~~
- Server
- ~~BGSAVE~~
- ~~BGWRITEAOF~~
- ~~CLIENT *~~
- ~~CONFIG *~~
- ~~DEBUG *~~
- ~~LASTSAVE~~
- ~~MONITOR~~
- ~~ROLE~~
- ~~SAVE~~
- ~~SHUTDOWN~~
- ~~SLAVEOF~~
- ~~SLOWLOG~~
- ~~SYNC~~
## &c.
Integration tests are run against Redis 6.2.6. The [./integration](./integration/) subdir
compares miniredis against a real redis instance.
The Redis 6 RESP3 protocol is supported. If there are problems, please open
an issue.
If you want to test Redis Sentinel have a look at [minisentinel](https://github.com/Bose/minisentinel).
A changelog is kept at [CHANGELOG.md](https://github.com/alicebob/miniredis/blob/master/CHANGELOG.md).
[![Go Reference](https://pkg.go.dev/badge/github.com/alicebob/miniredis/v2.svg)](https://pkg.go.dev/github.com/alicebob/miniredis/v2)

View File

@ -1,63 +0,0 @@
package miniredis
import (
"reflect"
"sort"
)
// T is implemented by Testing.T
type T interface {
Helper()
Errorf(string, ...interface{})
}
// CheckGet does not call Errorf() iff there is a string key with the
// expected value. Normal use case is `m.CheckGet(t, "username", "theking")`.
func (m *Miniredis) CheckGet(t T, key, expected string) {
t.Helper()
found, err := m.Get(key)
if err != nil {
t.Errorf("GET error, key %#v: %v", key, err)
return
}
if found != expected {
t.Errorf("GET error, key %#v: Expected %#v, got %#v", key, expected, found)
return
}
}
// CheckList does not call Errorf() iff there is a list key with the
// expected values.
// Normal use case is `m.CheckGet(t, "favorite_colors", "red", "green", "infrared")`.
func (m *Miniredis) CheckList(t T, key string, expected ...string) {
t.Helper()
found, err := m.List(key)
if err != nil {
t.Errorf("List error, key %#v: %v", key, err)
return
}
if !reflect.DeepEqual(expected, found) {
t.Errorf("List error, key %#v: Expected %#v, got %#v", key, expected, found)
return
}
}
// CheckSet does not call Errorf() iff there is a set key with the
// expected values.
// Normal use case is `m.CheckSet(t, "visited", "Rome", "Stockholm", "Dublin")`.
func (m *Miniredis) CheckSet(t T, key string, expected ...string) {
t.Helper()
found, err := m.Members(key)
if err != nil {
t.Errorf("Set error, key %#v: %v", key, err)
return
}
sort.Strings(expected)
if !reflect.DeepEqual(expected, found) {
t.Errorf("Set error, key %#v: Expected %#v, got %#v", key, expected, found)
return
}
}

View File

@ -1,67 +0,0 @@
// Commands from https://redis.io/commands#cluster
package miniredis
import (
"fmt"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsCluster handles some cluster operations.
func commandsCluster(m *Miniredis) {
m.srv.Register("CLUSTER", m.cmdCluster)
}
func (m *Miniredis) cmdCluster(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
switch strings.ToUpper(args[0]) {
case "SLOTS":
m.cmdClusterSlots(c, cmd, args)
case "KEYSLOT":
m.cmdClusterKeySlot(c, cmd, args)
case "NODES":
m.cmdClusterNodes(c, cmd, args)
default:
setDirty(c)
c.WriteError(fmt.Sprintf("ERR 'CLUSTER %s' not supported", strings.Join(args, " ")))
return
}
}
// CLUSTER SLOTS
func (m *Miniredis) cmdClusterSlots(c *server.Peer, cmd string, args []string) {
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteLen(1)
c.WriteLen(3)
c.WriteInt(0)
c.WriteInt(16383)
c.WriteLen(3)
c.WriteBulk(m.srv.Addr().IP.String())
c.WriteInt(m.srv.Addr().Port)
c.WriteBulk("09dbe9720cda62f7865eabc5fd8857c5d2678366")
})
}
// CLUSTER KEYSLOT
func (m *Miniredis) cmdClusterKeySlot(c *server.Peer, cmd string, args []string) {
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteInt(163)
})
}
// CLUSTER NODES
func (m *Miniredis) cmdClusterNodes(c *server.Peer, cmd string, args []string) {
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteBulk("e7d1eecce10fd6bb5eb35b9f99a514335d9ba9ca 127.0.0.1:7000@7000 myself,master - 0 0 1 connected 0-16383")
})
}

File diff suppressed because it is too large Load Diff

View File

@ -1,284 +0,0 @@
// Commands from https://redis.io/commands#connection
package miniredis
import (
"fmt"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
func commandsConnection(m *Miniredis) {
m.srv.Register("AUTH", m.cmdAuth)
m.srv.Register("ECHO", m.cmdEcho)
m.srv.Register("HELLO", m.cmdHello)
m.srv.Register("PING", m.cmdPing)
m.srv.Register("QUIT", m.cmdQuit)
m.srv.Register("SELECT", m.cmdSelect)
m.srv.Register("SWAPDB", m.cmdSwapdb)
}
// PING
func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if len(args) > 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
payload := ""
if len(args) > 0 {
payload = args[0]
}
// PING is allowed in subscribed state
if sub := getCtx(c).subscriber; sub != nil {
c.Block(func(c *server.Writer) {
c.WriteLen(2)
c.WriteBulk("pong")
c.WriteBulk(payload)
})
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if payload == "" {
c.WriteInline("PONG")
return
}
c.WriteBulk(payload)
})
}
// AUTH
func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if len(args) > 2 {
c.WriteError(msgSyntaxError)
return
}
if m.checkPubsub(c, cmd) {
return
}
if getCtx(c).nested {
c.WriteError(msgNotFromScripts)
return
}
var opts = struct {
username string
password string
}{
username: "default",
password: args[0],
}
if len(args) == 2 {
opts.username, opts.password = args[0], args[1]
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if len(m.passwords) == 0 && opts.username == "default" {
c.WriteError("ERR AUTH <password> called without any password configured for the default user. Are you sure your configuration is correct?")
return
}
setPW, ok := m.passwords[opts.username]
if !ok {
c.WriteError("WRONGPASS invalid username-password pair")
return
}
if setPW != opts.password {
c.WriteError("WRONGPASS invalid username-password pair")
return
}
ctx.authenticated = true
c.WriteOK()
})
}
// HELLO
func (m *Miniredis) cmdHello(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
c.WriteError(errWrongNumber(cmd))
return
}
var opts struct {
version int
username string
password string
}
if ok := optIntErr(c, args[0], &opts.version, "ERR Protocol version is not an integer or out of range"); !ok {
return
}
args = args[1:]
switch opts.version {
case 2, 3:
default:
c.WriteError("NOPROTO unsupported protocol version")
return
}
var checkAuth bool
for len(args) > 0 {
switch strings.ToUpper(args[0]) {
case "AUTH":
if len(args) < 3 {
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
return
}
opts.username, opts.password, args = args[1], args[2], args[3:]
checkAuth = true
case "SETNAME":
if len(args) < 2 {
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
return
}
_, args = args[1], args[2:]
default:
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
return
}
}
if len(m.passwords) == 0 && opts.username == "default" {
// redis ignores legacy "AUTH" if it's not enabled.
checkAuth = false
}
if checkAuth {
setPW, ok := m.passwords[opts.username]
if !ok {
c.WriteError("WRONGPASS invalid username-password pair")
return
}
if setPW != opts.password {
c.WriteError("WRONGPASS invalid username-password pair")
return
}
getCtx(c).authenticated = true
}
c.Resp3 = opts.version == 3
c.WriteMapLen(7)
c.WriteBulk("server")
c.WriteBulk("miniredis")
c.WriteBulk("version")
c.WriteBulk("6.0.5")
c.WriteBulk("proto")
c.WriteInt(opts.version)
c.WriteBulk("id")
c.WriteInt(42)
c.WriteBulk("mode")
c.WriteBulk("standalone")
c.WriteBulk("role")
c.WriteBulk("master")
c.WriteBulk("modules")
c.WriteLen(0)
}
// ECHO
func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
msg := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteBulk(msg)
})
}
// SELECT
func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.isValidCMD(c, cmd) {
return
}
var opts struct {
id int
}
if ok := optInt(c, args[0], &opts.id); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if opts.id < 0 {
c.WriteError(msgDBIndexOutOfRange)
setDirty(c)
return
}
ctx.selectedDB = opts.id
c.WriteOK()
})
}
// SWAPDB
func (m *Miniredis) cmdSwapdb(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
var opts struct {
id1 int
id2 int
}
if ok := optIntErr(c, args[0], &opts.id1, "ERR invalid first DB index"); !ok {
return
}
if ok := optIntErr(c, args[1], &opts.id2, "ERR invalid second DB index"); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if opts.id1 < 0 || opts.id2 < 0 {
c.WriteError(msgDBIndexOutOfRange)
setDirty(c)
return
}
m.swapDB(opts.id1, opts.id2)
c.WriteOK()
})
}
// QUIT
func (m *Miniredis) cmdQuit(c *server.Peer, cmd string, args []string) {
// QUIT isn't transactionfied and accepts any arguments.
c.WriteOK()
c.Close()
}

View File

@ -1,669 +0,0 @@
// Commands from https://redis.io/commands#generic
package miniredis
import (
"sort"
"strconv"
"strings"
"time"
"github.com/alicebob/miniredis/v2/server"
)
// commandsGeneric handles EXPIRE, TTL, PERSIST, &c.
func commandsGeneric(m *Miniredis) {
m.srv.Register("COPY", m.cmdCopy)
m.srv.Register("DEL", m.cmdDel)
// DUMP
m.srv.Register("EXISTS", m.cmdExists)
m.srv.Register("EXPIRE", makeCmdExpire(m, false, time.Second))
m.srv.Register("EXPIREAT", makeCmdExpire(m, true, time.Second))
m.srv.Register("KEYS", m.cmdKeys)
// MIGRATE
m.srv.Register("MOVE", m.cmdMove)
// OBJECT
m.srv.Register("PERSIST", m.cmdPersist)
m.srv.Register("PEXPIRE", makeCmdExpire(m, false, time.Millisecond))
m.srv.Register("PEXPIREAT", makeCmdExpire(m, true, time.Millisecond))
m.srv.Register("PTTL", m.cmdPTTL)
m.srv.Register("RANDOMKEY", m.cmdRandomkey)
m.srv.Register("RENAME", m.cmdRename)
m.srv.Register("RENAMENX", m.cmdRenamenx)
// RESTORE
m.srv.Register("TOUCH", m.cmdTouch)
m.srv.Register("TTL", m.cmdTTL)
m.srv.Register("TYPE", m.cmdType)
m.srv.Register("SCAN", m.cmdScan)
// SORT
m.srv.Register("UNLINK", m.cmdDel)
}
// generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT
// d is the time unit. If unix is set it'll be seen as a unixtimestamp and
// converted to a duration.
func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, string, []string) {
return func(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
key string
value int
}
opts.key = args[0]
if ok := optInt(c, args[1], &opts.value); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
// Key must be present.
if _, ok := db.keys[opts.key]; !ok {
c.WriteInt(0)
return
}
if unix {
db.ttl[opts.key] = m.at(opts.value, d)
} else {
db.ttl[opts.key] = time.Duration(opts.value) * d
}
db.keyVersion[opts.key]++
db.checkTTL(opts.key)
c.WriteInt(1)
})
}
}
// TOUCH
func (m *Miniredis) cmdTouch(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
count := 0
for _, key := range args {
if db.exists(key) {
count++
}
}
c.WriteInt(count)
})
}
// TTL
func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if _, ok := db.keys[key]; !ok {
// No such key
c.WriteInt(-2)
return
}
v, ok := db.ttl[key]
if !ok {
// no expire value
c.WriteInt(-1)
return
}
c.WriteInt(int(v.Seconds()))
})
}
// PTTL
func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if _, ok := db.keys[key]; !ok {
// no such key
c.WriteInt(-2)
return
}
v, ok := db.ttl[key]
if !ok {
// no expire value
c.WriteInt(-1)
return
}
c.WriteInt(int(v.Nanoseconds() / 1000000))
})
}
// PERSIST
func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if _, ok := db.keys[key]; !ok {
// no such key
c.WriteInt(0)
return
}
if _, ok := db.ttl[key]; !ok {
// no expire value
c.WriteInt(0)
return
}
delete(db.ttl, key)
db.keyVersion[key]++
c.WriteInt(1)
})
}
// DEL and UNLINK
func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
count := 0
for _, key := range args {
if db.exists(key) {
count++
}
db.del(key, true) // delete expire
}
c.WriteInt(count)
})
}
// TYPE
func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError("usage error")
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteInline("none")
return
}
c.WriteInline(t)
})
}
// EXISTS
func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
found := 0
for _, k := range args {
if db.exists(k) {
found++
}
}
c.WriteInt(found)
})
}
// MOVE
func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
key string
targetDB int
}
opts.key = args[0]
opts.targetDB, _ = strconv.Atoi(args[1])
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if ctx.selectedDB == opts.targetDB {
c.WriteError("ERR source and destination objects are the same")
return
}
db := m.db(ctx.selectedDB)
targetDB := m.db(opts.targetDB)
if !db.move(opts.key, targetDB) {
c.WriteInt(0)
return
}
c.WriteInt(1)
})
}
// KEYS
func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
keys, _ := matchKeys(db.allKeys(), key)
c.WriteLen(len(keys))
for _, s := range keys {
c.WriteBulk(s)
}
})
}
// RANDOMKEY
func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) {
if len(args) != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if len(db.keys) == 0 {
c.WriteNull()
return
}
nr := m.randIntn(len(db.keys))
for k := range db.keys {
if nr == 0 {
c.WriteBulk(k)
return
}
nr--
}
})
}
// RENAME
func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
from string
to string
}{
from: args[0],
to: args[1],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.from) {
c.WriteError(msgKeyNotFound)
return
}
db.rename(opts.from, opts.to)
c.WriteOK()
})
}
// RENAMENX
func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
from string
to string
}{
from: args[0],
to: args[1],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.from) {
c.WriteError(msgKeyNotFound)
return
}
if db.exists(opts.to) {
c.WriteInt(0)
return
}
db.rename(opts.from, opts.to)
c.WriteInt(1)
})
}
// SCAN
func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
cursor int
withMatch bool
match string
withType bool
_type string
}
if ok := optIntErr(c, args[0], &opts.cursor, msgInvalidCursor); !ok {
return
}
args = args[1:]
// MATCH, COUNT and TYPE options
for len(args) > 0 {
if strings.ToLower(args[0]) == "count" {
// we do nothing with count
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
if _, err := strconv.Atoi(args[1]); err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
args = args[2:]
continue
}
if strings.ToLower(args[0]) == "match" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
opts.withMatch = true
opts.match, args = args[1], args[2:]
continue
}
if strings.ToLower(args[0]) == "type" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
opts.withType = true
opts._type, args = strings.ToLower(args[1]), args[2:]
continue
}
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
// We return _all_ (matched) keys every time.
if opts.cursor != 0 {
// Invalid cursor.
c.WriteLen(2)
c.WriteBulk("0") // no next cursor
c.WriteLen(0) // no elements
return
}
var keys []string
if opts.withType {
keys = make([]string, 0)
for k, t := range db.keys {
// type must be given exactly; no pattern matching is performed
if t == opts._type {
keys = append(keys, k)
}
}
sort.Strings(keys) // To make things deterministic.
} else {
keys = db.allKeys()
}
if opts.withMatch {
keys, _ = matchKeys(keys, opts.match)
}
c.WriteLen(2)
c.WriteBulk("0") // no next cursor
c.WriteLen(len(keys))
for _, k := range keys {
c.WriteBulk(k)
}
})
}
// COPY
func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts = struct {
from string
to string
destinationDB int
replace bool
}{
destinationDB: -1,
}
opts.from, opts.to, args = args[0], args[1], args[2:]
for len(args) > 0 {
switch strings.ToLower(args[0]) {
case "db":
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
db, err := strconv.Atoi(args[1])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if db < 0 {
setDirty(c)
c.WriteError(msgDBIndexOutOfRange)
return
}
opts.destinationDB = db
args = args[2:]
case "replace":
opts.replace = true
args = args[1:]
default:
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
fromDB, toDB := ctx.selectedDB, opts.destinationDB
if toDB == -1 {
toDB = fromDB
}
if fromDB == toDB && opts.from == opts.to {
c.WriteError("ERR source and destination objects are the same")
return
}
if !m.db(fromDB).exists(opts.from) {
c.WriteInt(0)
return
}
if !opts.replace {
if m.db(toDB).exists(opts.to) {
c.WriteInt(0)
return
}
}
m.copy(m.db(fromDB), opts.from, m.db(toDB), opts.to)
c.WriteInt(1)
})
}

View File

@ -1,609 +0,0 @@
// Commands from https://redis.io/commands#geo
package miniredis
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsGeo handles GEOADD, GEORADIUS etc.
func commandsGeo(m *Miniredis) {
m.srv.Register("GEOADD", m.cmdGeoadd)
m.srv.Register("GEODIST", m.cmdGeodist)
m.srv.Register("GEOPOS", m.cmdGeopos)
m.srv.Register("GEORADIUS", m.cmdGeoradius)
m.srv.Register("GEORADIUS_RO", m.cmdGeoradius)
m.srv.Register("GEORADIUSBYMEMBER", m.cmdGeoradiusbymember)
m.srv.Register("GEORADIUSBYMEMBER_RO", m.cmdGeoradiusbymember)
}
// GEOADD
func (m *Miniredis) cmdGeoadd(c *server.Peer, cmd string, args []string) {
if len(args) < 3 || len(args[1:])%3 != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, args := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if db.exists(key) && db.t(key) != "zset" {
c.WriteError(ErrWrongType.Error())
return
}
toSet := map[string]float64{}
for len(args) > 2 {
rawLong, rawLat, name := args[0], args[1], args[2]
args = args[3:]
longitude, err := strconv.ParseFloat(rawLong, 64)
if err != nil {
c.WriteError("ERR value is not a valid float")
return
}
latitude, err := strconv.ParseFloat(rawLat, 64)
if err != nil {
c.WriteError("ERR value is not a valid float")
return
}
if latitude < -85.05112878 ||
latitude > 85.05112878 ||
longitude < -180 ||
longitude > 180 {
c.WriteError(fmt.Sprintf("ERR invalid longitude,latitude pair %.6f,%.6f", longitude, latitude))
return
}
toSet[name] = float64(toGeohash(longitude, latitude))
}
set := 0
for name, score := range toSet {
if db.ssetAdd(key, score, name) {
set++
}
}
c.WriteInt(set)
})
}
// GEODIST
func (m *Miniredis) cmdGeodist(c *server.Peer, cmd string, args []string) {
if len(args) < 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, from, to, args := args[0], args[1], args[2], args[3:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteNull()
return
}
if db.t(key) != "zset" {
c.WriteError(ErrWrongType.Error())
return
}
unit := "m"
if len(args) > 0 {
unit, args = args[0], args[1:]
}
if len(args) > 0 {
c.WriteError(msgSyntaxError)
return
}
toMeter := parseUnit(unit)
if toMeter == 0 {
c.WriteError(msgUnsupportedUnit)
return
}
members := db.sortedsetKeys[key]
fromD, okFrom := members.get(from)
toD, okTo := members.get(to)
if !okFrom || !okTo {
c.WriteNull()
return
}
fromLo, fromLat := fromGeohash(uint64(fromD))
toLo, toLat := fromGeohash(uint64(toD))
dist := distance(fromLat, fromLo, toLat, toLo) / toMeter
c.WriteBulk(fmt.Sprintf("%.4f", dist))
})
}
// GEOPOS
func (m *Miniredis) cmdGeopos(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, args := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if db.exists(key) && db.t(key) != "zset" {
c.WriteError(ErrWrongType.Error())
return
}
c.WriteLen(len(args))
for _, l := range args {
if !db.ssetExists(key, l) {
c.WriteLen(-1)
continue
}
score := db.ssetScore(key, l)
c.WriteLen(2)
long, lat := fromGeohash(uint64(score))
c.WriteBulk(fmt.Sprintf("%f", long))
c.WriteBulk(fmt.Sprintf("%f", lat))
}
})
}
type geoDistance struct {
Name string
Score float64
Distance float64
Longitude float64
Latitude float64
}
// GEORADIUS and GEORADIUS_RO
func (m *Miniredis) cmdGeoradius(c *server.Peer, cmd string, args []string) {
if len(args) < 5 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
longitude, err := strconv.ParseFloat(args[1], 64)
if err != nil {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
latitude, err := strconv.ParseFloat(args[2], 64)
if err != nil {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
radius, err := strconv.ParseFloat(args[3], 64)
if err != nil || radius < 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
toMeter := parseUnit(args[4])
if toMeter == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
args = args[5:]
var opts struct {
withDist bool
withCoord bool
direction direction // unsorted
count int
withStore bool
storeKey string
withStoredist bool
storedistKey string
}
for len(args) > 0 {
arg := args[0]
args = args[1:]
switch strings.ToUpper(arg) {
case "WITHCOORD":
opts.withCoord = true
case "WITHDIST":
opts.withDist = true
case "ASC":
opts.direction = asc
case "DESC":
opts.direction = desc
case "COUNT":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
n, err := strconv.Atoi(args[0])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if n <= 0 {
setDirty(c)
c.WriteError("ERR COUNT must be > 0")
return
}
args = args[1:]
opts.count = n
case "STORE":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
opts.withStore = true
opts.storeKey = args[0]
args = args[1:]
case "STOREDIST":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
opts.withStoredist = true
opts.storedistKey = args[0]
args = args[1:]
default:
setDirty(c)
c.WriteError("ERR syntax error")
return
}
}
if strings.ToUpper(cmd) == "GEORADIUS_RO" && (opts.withStore || opts.withStoredist) {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
return
}
db := m.db(ctx.selectedDB)
members := db.ssetElements(key)
matches := withinRadius(members, longitude, latitude, radius*toMeter)
// deal with ASC/DESC
if opts.direction != unsorted {
sort.Slice(matches, func(i, j int) bool {
if opts.direction == desc {
return matches[i].Distance > matches[j].Distance
}
return matches[i].Distance < matches[j].Distance
})
}
// deal with COUNT
if opts.count > 0 && len(matches) > opts.count {
matches = matches[:opts.count]
}
// deal with "STORE x"
if opts.withStore {
db.del(opts.storeKey, true)
for _, member := range matches {
db.ssetAdd(opts.storeKey, member.Score, member.Name)
}
c.WriteInt(len(matches))
return
}
// deal with "STOREDIST x"
if opts.withStoredist {
db.del(opts.storedistKey, true)
for _, member := range matches {
db.ssetAdd(opts.storedistKey, member.Distance/toMeter, member.Name)
}
c.WriteInt(len(matches))
return
}
c.WriteLen(len(matches))
for _, member := range matches {
if !opts.withDist && !opts.withCoord {
c.WriteBulk(member.Name)
continue
}
len := 1
if opts.withDist {
len++
}
if opts.withCoord {
len++
}
c.WriteLen(len)
c.WriteBulk(member.Name)
if opts.withDist {
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/toMeter))
}
if opts.withCoord {
c.WriteLen(2)
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
}
}
})
}
// GEORADIUSBYMEMBER and GEORADIUSBYMEMBER_RO
func (m *Miniredis) cmdGeoradiusbymember(c *server.Peer, cmd string, args []string) {
if len(args) < 4 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
member string
radius float64
toMeter float64
withDist bool
withCoord bool
direction direction // unsorted
count int
withStore bool
storeKey string
withStoredist bool
storedistKey string
}{
key: args[0],
member: args[1],
}
r, err := strconv.ParseFloat(args[2], 64)
if err != nil || r < 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
opts.radius = r
opts.toMeter = parseUnit(args[3])
if opts.toMeter == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
args = args[4:]
for len(args) > 0 {
arg := args[0]
args = args[1:]
switch strings.ToUpper(arg) {
case "WITHCOORD":
opts.withCoord = true
case "WITHDIST":
opts.withDist = true
case "ASC":
opts.direction = asc
case "DESC":
opts.direction = desc
case "COUNT":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
n, err := strconv.Atoi(args[0])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if n <= 0 {
setDirty(c)
c.WriteError("ERR COUNT must be > 0")
return
}
args = args[1:]
opts.count = n
case "STORE":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
opts.withStore = true
opts.storeKey = args[0]
args = args[1:]
case "STOREDIST":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
opts.withStoredist = true
opts.storedistKey = args[0]
args = args[1:]
default:
setDirty(c)
c.WriteError("ERR syntax error")
return
}
}
if strings.ToUpper(cmd) == "GEORADIUSBYMEMBER_RO" && (opts.withStore || opts.withStoredist) {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
return
}
db := m.db(ctx.selectedDB)
if !db.exists(opts.key) {
c.WriteNull()
return
}
if db.t(opts.key) != "zset" {
c.WriteError(ErrWrongType.Error())
return
}
// get position of member
if !db.ssetExists(opts.key, opts.member) {
c.WriteError("ERR could not decode requested zset member")
return
}
score := db.ssetScore(opts.key, opts.member)
longitude, latitude := fromGeohash(uint64(score))
members := db.ssetElements(opts.key)
matches := withinRadius(members, longitude, latitude, opts.radius*opts.toMeter)
// deal with ASC/DESC
if opts.direction != unsorted {
sort.Slice(matches, func(i, j int) bool {
if opts.direction == desc {
return matches[i].Distance > matches[j].Distance
}
return matches[i].Distance < matches[j].Distance
})
}
// deal with COUNT
if opts.count > 0 && len(matches) > opts.count {
matches = matches[:opts.count]
}
// deal with "STORE x"
if opts.withStore {
db.del(opts.storeKey, true)
for _, member := range matches {
db.ssetAdd(opts.storeKey, member.Score, member.Name)
}
c.WriteInt(len(matches))
return
}
// deal with "STOREDIST x"
if opts.withStoredist {
db.del(opts.storedistKey, true)
for _, member := range matches {
db.ssetAdd(opts.storedistKey, member.Distance/opts.toMeter, member.Name)
}
c.WriteInt(len(matches))
return
}
c.WriteLen(len(matches))
for _, member := range matches {
if !opts.withDist && !opts.withCoord {
c.WriteBulk(member.Name)
continue
}
len := 1
if opts.withDist {
len++
}
if opts.withCoord {
len++
}
c.WriteLen(len)
c.WriteBulk(member.Name)
if opts.withDist {
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/opts.toMeter))
}
if opts.withCoord {
c.WriteLen(2)
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
}
}
})
}
func withinRadius(members []ssElem, longitude, latitude, radius float64) []geoDistance {
matches := []geoDistance{}
for _, el := range members {
elLo, elLat := fromGeohash(uint64(el.score))
distanceInMeter := distance(latitude, longitude, elLat, elLo)
if distanceInMeter <= radius {
matches = append(matches, geoDistance{
Name: el.member,
Score: el.score,
Distance: distanceInMeter,
Longitude: elLo,
Latitude: elLat,
})
}
}
return matches
}
func parseUnit(u string) float64 {
switch u {
case "m":
return 1
case "km":
return 1000
case "mi":
return 1609.34
case "ft":
return 0.3048
default:
return 0
}
}

View File

@ -1,683 +0,0 @@
// Commands from https://redis.io/commands#hash
package miniredis
import (
"math/big"
"strconv"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsHash handles all hash value operations.
func commandsHash(m *Miniredis) {
m.srv.Register("HDEL", m.cmdHdel)
m.srv.Register("HEXISTS", m.cmdHexists)
m.srv.Register("HGET", m.cmdHget)
m.srv.Register("HGETALL", m.cmdHgetall)
m.srv.Register("HINCRBY", m.cmdHincrby)
m.srv.Register("HINCRBYFLOAT", m.cmdHincrbyfloat)
m.srv.Register("HKEYS", m.cmdHkeys)
m.srv.Register("HLEN", m.cmdHlen)
m.srv.Register("HMGET", m.cmdHmget)
m.srv.Register("HMSET", m.cmdHmset)
m.srv.Register("HSET", m.cmdHset)
m.srv.Register("HSETNX", m.cmdHsetnx)
m.srv.Register("HSTRLEN", m.cmdHstrlen)
m.srv.Register("HVALS", m.cmdHvals)
m.srv.Register("HSCAN", m.cmdHscan)
}
// HSET
func (m *Miniredis) cmdHset(c *server.Peer, cmd string, args []string) {
if len(args) < 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, pairs := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if len(pairs)%2 == 1 {
c.WriteError(errWrongNumber(cmd))
return
}
if t, ok := db.keys[key]; ok && t != "hash" {
c.WriteError(msgWrongType)
return
}
new := db.hashSet(key, pairs...)
c.WriteInt(new)
})
}
// HSETNX
func (m *Miniredis) cmdHsetnx(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
field string
value string
}{
key: args[0],
field: args[1],
value: args[2],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[opts.key]; ok && t != "hash" {
c.WriteError(msgWrongType)
return
}
if _, ok := db.hashKeys[opts.key]; !ok {
db.hashKeys[opts.key] = map[string]string{}
db.keys[opts.key] = "hash"
}
_, ok := db.hashKeys[opts.key][opts.field]
if ok {
c.WriteInt(0)
return
}
db.hashKeys[opts.key][opts.field] = opts.value
db.keyVersion[opts.key]++
c.WriteInt(1)
})
}
// HMSET
func (m *Miniredis) cmdHmset(c *server.Peer, cmd string, args []string) {
if len(args) < 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, args := args[0], args[1:]
if len(args)%2 != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[key]; ok && t != "hash" {
c.WriteError(msgWrongType)
return
}
for len(args) > 0 {
field, value := args[0], args[1]
args = args[2:]
db.hashSet(key, field, value)
}
c.WriteOK()
})
}
// HGET
func (m *Miniredis) cmdHget(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, field := args[0], args[1]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteNull()
return
}
if t != "hash" {
c.WriteError(msgWrongType)
return
}
value, ok := db.hashKeys[key][field]
if !ok {
c.WriteNull()
return
}
c.WriteBulk(value)
})
}
// HDEL
func (m *Miniredis) cmdHdel(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
fields []string
}{
key: args[0],
fields: args[1:],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[opts.key]
if !ok {
// No key is zero deleted
c.WriteInt(0)
return
}
if t != "hash" {
c.WriteError(msgWrongType)
return
}
deleted := 0
for _, f := range opts.fields {
_, ok := db.hashKeys[opts.key][f]
if !ok {
continue
}
delete(db.hashKeys[opts.key], f)
deleted++
}
c.WriteInt(deleted)
// Nothing left. Remove the whole key.
if len(db.hashKeys[opts.key]) == 0 {
db.del(opts.key, true)
}
})
}
// HEXISTS
func (m *Miniredis) cmdHexists(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
field string
}{
key: args[0],
field: args[1],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[opts.key]
if !ok {
c.WriteInt(0)
return
}
if t != "hash" {
c.WriteError(msgWrongType)
return
}
if _, ok := db.hashKeys[opts.key][opts.field]; !ok {
c.WriteInt(0)
return
}
c.WriteInt(1)
})
}
// HGETALL
func (m *Miniredis) cmdHgetall(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteMapLen(0)
return
}
if t != "hash" {
c.WriteError(msgWrongType)
return
}
c.WriteMapLen(len(db.hashKeys[key]))
for _, k := range db.hashFields(key) {
c.WriteBulk(k)
c.WriteBulk(db.hashGet(key, k))
}
})
}
// HKEYS
func (m *Miniredis) cmdHkeys(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteLen(0)
return
}
if db.t(key) != "hash" {
c.WriteError(msgWrongType)
return
}
fields := db.hashFields(key)
c.WriteLen(len(fields))
for _, f := range fields {
c.WriteBulk(f)
}
})
}
// HSTRLEN
func (m *Miniredis) cmdHstrlen(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
hash, key := args[0], args[1]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[hash]
if !ok {
c.WriteInt(0)
return
}
if t != "hash" {
c.WriteError(msgWrongType)
return
}
keys := db.hashKeys[hash]
c.WriteInt(len(keys[key]))
})
}
// HVALS
func (m *Miniredis) cmdHvals(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteLen(0)
return
}
if t != "hash" {
c.WriteError(msgWrongType)
return
}
vals := db.hashValues(key)
c.WriteLen(len(vals))
for _, v := range vals {
c.WriteBulk(v)
}
})
}
// HLEN
func (m *Miniredis) cmdHlen(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteInt(0)
return
}
if t != "hash" {
c.WriteError(msgWrongType)
return
}
c.WriteInt(len(db.hashKeys[key]))
})
}
// HMGET
func (m *Miniredis) cmdHmget(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[key]; ok && t != "hash" {
c.WriteError(msgWrongType)
return
}
f, ok := db.hashKeys[key]
if !ok {
f = map[string]string{}
}
c.WriteLen(len(args) - 1)
for _, k := range args[1:] {
v, ok := f[k]
if !ok {
c.WriteNull()
continue
}
c.WriteBulk(v)
}
})
}
// HINCRBY
func (m *Miniredis) cmdHincrby(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
field string
delta int
}{
key: args[0],
field: args[1],
}
if ok := optInt(c, args[2], &opts.delta); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[opts.key]; ok && t != "hash" {
c.WriteError(msgWrongType)
return
}
v, err := db.hashIncr(opts.key, opts.field, opts.delta)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteInt(v)
})
}
// HINCRBYFLOAT
func (m *Miniredis) cmdHincrbyfloat(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
field string
delta *big.Float
}{
key: args[0],
field: args[1],
}
delta, _, err := big.ParseFloat(args[2], 10, 128, 0)
if err != nil {
setDirty(c)
c.WriteError(msgInvalidFloat)
return
}
opts.delta = delta
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[opts.key]; ok && t != "hash" {
c.WriteError(msgWrongType)
return
}
v, err := db.hashIncrfloat(opts.key, opts.field, opts.delta)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteBulk(formatBig(v))
})
}
// HSCAN
func (m *Miniredis) cmdHscan(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
cursor int
withMatch bool
match string
}{
key: args[0],
}
if ok := optIntErr(c, args[1], &opts.cursor, msgInvalidCursor); !ok {
return
}
args = args[2:]
// MATCH and COUNT options
for len(args) > 0 {
if strings.ToLower(args[0]) == "count" {
// we do nothing with count
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
_, err := strconv.Atoi(args[1])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
args = args[2:]
continue
}
if strings.ToLower(args[0]) == "match" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
opts.withMatch = true
opts.match, args = args[1], args[2:]
continue
}
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
// return _all_ (matched) keys every time
if opts.cursor != 0 {
// Invalid cursor.
c.WriteLen(2)
c.WriteBulk("0") // no next cursor
c.WriteLen(0) // no elements
return
}
if db.exists(opts.key) && db.t(opts.key) != "hash" {
c.WriteError(ErrWrongType.Error())
return
}
members := db.hashFields(opts.key)
if opts.withMatch {
members, _ = matchKeys(members, opts.match)
}
c.WriteLen(2)
c.WriteBulk("0") // no next cursor
// HSCAN gives key, values.
c.WriteLen(len(members) * 2)
for _, k := range members {
c.WriteBulk(k)
c.WriteBulk(db.hashGet(opts.key, k))
}
})
}

View File

@ -1,95 +0,0 @@
package miniredis
import "github.com/alicebob/miniredis/v2/server"
// commandsHll handles all hll related operations.
func commandsHll(m *Miniredis) {
m.srv.Register("PFADD", m.cmdPfadd)
m.srv.Register("PFCOUNT", m.cmdPfcount)
m.srv.Register("PFMERGE", m.cmdPfmerge)
}
// PFADD
func (m *Miniredis) cmdPfadd(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, items := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if db.exists(key) && db.t(key) != "hll" {
c.WriteError(ErrNotValidHllValue.Error())
return
}
altered := db.hllAdd(key, items...)
c.WriteInt(altered)
})
}
// PFCOUNT
func (m *Miniredis) cmdPfcount(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
keys := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
count, err := db.hllCount(keys)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteInt(count)
})
}
// PFMERGE
func (m *Miniredis) cmdPfmerge(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
keys := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if err := db.hllMerge(keys); err != nil {
c.WriteError(err.Error())
return
}
c.WriteOK()
})
}

View File

@ -1,40 +0,0 @@
package miniredis
import (
"fmt"
"github.com/alicebob/miniredis/v2/server"
)
// Command 'INFO' from https://redis.io/commands/info/
func (m *Miniredis) cmdInfo(c *server.Peer, cmd string, args []string) {
if !m.isValidCMD(c, cmd) {
return
}
if len(args) > 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
const (
clientsSectionName = "clients"
clientsSectionContent = "# Clients\nconnected_clients:%d\r\n"
)
var result string
for _, key := range args {
if key != clientsSectionName {
setDirty(c)
c.WriteError(fmt.Sprintf("section (%s) is not supported", key))
return
}
}
result = fmt.Sprintf(clientsSectionContent, m.Server().ClientsLen())
c.WriteBulk(result)
})
}

View File

@ -1,986 +0,0 @@
// Commands from https://redis.io/commands#list
package miniredis
import (
"strconv"
"strings"
"time"
"github.com/alicebob/miniredis/v2/server"
)
type leftright int
const (
left leftright = iota
right
)
// commandsList handles list commands (mostly L*)
func commandsList(m *Miniredis) {
m.srv.Register("BLPOP", m.cmdBlpop)
m.srv.Register("BRPOP", m.cmdBrpop)
m.srv.Register("BRPOPLPUSH", m.cmdBrpoplpush)
m.srv.Register("LINDEX", m.cmdLindex)
m.srv.Register("LPOS", m.cmdLpos)
m.srv.Register("LINSERT", m.cmdLinsert)
m.srv.Register("LLEN", m.cmdLlen)
m.srv.Register("LPOP", m.cmdLpop)
m.srv.Register("LPUSH", m.cmdLpush)
m.srv.Register("LPUSHX", m.cmdLpushx)
m.srv.Register("LRANGE", m.cmdLrange)
m.srv.Register("LREM", m.cmdLrem)
m.srv.Register("LSET", m.cmdLset)
m.srv.Register("LTRIM", m.cmdLtrim)
m.srv.Register("RPOP", m.cmdRpop)
m.srv.Register("RPOPLPUSH", m.cmdRpoplpush)
m.srv.Register("RPUSH", m.cmdRpush)
m.srv.Register("RPUSHX", m.cmdRpushx)
m.srv.Register("LMOVE", m.cmdLmove)
}
// BLPOP
func (m *Miniredis) cmdBlpop(c *server.Peer, cmd string, args []string) {
m.cmdBXpop(c, cmd, args, left)
}
// BRPOP
func (m *Miniredis) cmdBrpop(c *server.Peer, cmd string, args []string) {
m.cmdBXpop(c, cmd, args, right)
}
func (m *Miniredis) cmdBXpop(c *server.Peer, cmd string, args []string, lr leftright) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
timeoutS := args[len(args)-1]
keys := args[:len(args)-1]
timeout, err := strconv.Atoi(timeoutS)
if err != nil {
setDirty(c)
c.WriteError(msgInvalidTimeout)
return
}
if timeout < 0 {
setDirty(c)
c.WriteError(msgNegTimeout)
return
}
blocking(
m,
c,
time.Duration(timeout)*time.Second,
func(c *server.Peer, ctx *connCtx) bool {
db := m.db(ctx.selectedDB)
for _, key := range keys {
if !db.exists(key) {
continue
}
if db.t(key) != "list" {
c.WriteError(msgWrongType)
return true
}
if len(db.listKeys[key]) == 0 {
continue
}
c.WriteLen(2)
c.WriteBulk(key)
var v string
switch lr {
case left:
v = db.listLpop(key)
case right:
v = db.listPop(key)
}
c.WriteBulk(v)
return true
}
return false
},
func(c *server.Peer) {
// timeout
c.WriteLen(-1)
},
)
}
// LINDEX
func (m *Miniredis) cmdLindex(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, offsets := args[0], args[1]
offset, err := strconv.Atoi(offsets)
if err != nil || offsets == "-0" {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
// No such key
c.WriteNull()
return
}
if t != "list" {
c.WriteError(msgWrongType)
return
}
l := db.listKeys[key]
if offset < 0 {
offset = len(l) + offset
}
if offset < 0 || offset > len(l)-1 {
c.WriteNull()
return
}
c.WriteBulk(l[offset])
})
}
// LPOS key element [RANK rank] [COUNT num-matches] [MAXLEN len]
func (m *Miniredis) cmdLpos(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
if len(args) == 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
// Extract options from arguments if present.
//
// Redis allows duplicate options and uses the last specified.
// `LPOS key term RANK 1 RANK 2` is effectively the same as
// `LPOS key term RANK 2`
if len(args)%2 == 1 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
rank, count := 1, 1 // Default values
var maxlen int // Default value is the list length (see below)
var countSpecified, maxlenSpecified bool
if len(args) > 2 {
for i := 2; i < len(args); i++ {
if i%2 == 0 {
val := args[i+1]
var err error
switch strings.ToLower(args[i]) {
case "rank":
if rank, err = strconv.Atoi(val); err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if rank == 0 {
setDirty(c)
c.WriteError(msgRankIsZero)
return
}
case "count":
countSpecified = true
if count, err = strconv.Atoi(val); err != nil || count < 0 {
setDirty(c)
c.WriteError(msgCountIsNegative)
return
}
case "maxlen":
maxlenSpecified = true
if maxlen, err = strconv.Atoi(val); err != nil || maxlen < 0 {
setDirty(c)
c.WriteError(msgMaxLengthIsNegative)
return
}
default:
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
}
}
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
key, element := args[0], args[1]
t, ok := db.keys[key]
if !ok {
// No such key
c.WriteNull()
return
}
if t != "list" {
c.WriteError(msgWrongType)
return
}
l := db.listKeys[key]
// RANK cannot be zero (see above).
// If RANK is positive search forward (left to right).
// If RANK is negative search backward (right to left).
// Iterator returns true to continue iterating.
iterate := func(iterator func(i int, e string) bool) {
comparisons := len(l)
// Only use max length if specified, not zero, and less than total length.
// When max length is specified, but is zero, this means "unlimited".
if maxlenSpecified && maxlen != 0 && maxlen < len(l) {
comparisons = maxlen
}
if rank > 0 {
for i := 0; i < comparisons; i++ {
if resume := iterator(i, l[i]); !resume {
return
}
}
} else if rank < 0 {
start := len(l) - 1
end := len(l) - comparisons
for i := start; i >= end; i-- {
if resume := iterator(i, l[i]); !resume {
return
}
}
}
}
var currentRank, currentCount int
vals := make([]int, 0, count)
iterate(func(i int, e string) bool {
if e == element {
currentRank++
// Only collect values only after surpassing the absolute value of rank.
if rank > 0 && currentRank < rank {
return true
}
if rank < 0 && currentRank < -rank {
return true
}
vals = append(vals, i)
currentCount++
if currentCount == count {
return false
}
}
return true
})
if !countSpecified && len(vals) == 0 {
c.WriteNull()
return
}
if !countSpecified && len(vals) == 1 {
c.WriteInt(vals[0])
return
}
c.WriteLen(len(vals))
for _, val := range vals {
c.WriteInt(val)
}
})
}
// LINSERT
func (m *Miniredis) cmdLinsert(c *server.Peer, cmd string, args []string) {
if len(args) != 4 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
where := 0
switch strings.ToLower(args[1]) {
case "before":
where = -1
case "after":
where = +1
default:
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
pivot := args[2]
value := args[3]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
// No such key
c.WriteInt(0)
return
}
if t != "list" {
c.WriteError(msgWrongType)
return
}
l := db.listKeys[key]
for i, el := range l {
if el != pivot {
continue
}
if where < 0 {
l = append(l[:i], append(listKey{value}, l[i:]...)...)
} else {
if i == len(l)-1 {
l = append(l, value)
} else {
l = append(l[:i+1], append(listKey{value}, l[i+1:]...)...)
}
}
db.listKeys[key] = l
db.keyVersion[key]++
c.WriteInt(len(l))
return
}
c.WriteInt(-1)
})
}
// LLEN
func (m *Miniredis) cmdLlen(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
// No such key. That's zero length.
c.WriteInt(0)
return
}
if t != "list" {
c.WriteError(msgWrongType)
return
}
c.WriteInt(len(db.listKeys[key]))
})
}
// LPOP
func (m *Miniredis) cmdLpop(c *server.Peer, cmd string, args []string) {
m.cmdXpop(c, cmd, args, left)
}
// RPOP
func (m *Miniredis) cmdRpop(c *server.Peer, cmd string, args []string) {
m.cmdXpop(c, cmd, args, right)
}
func (m *Miniredis) cmdXpop(c *server.Peer, cmd string, args []string, lr leftright) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
key string
withCount bool
count int
}
opts.key, args = args[0], args[1:]
if len(args) > 0 {
if ok := optInt(c, args[0], &opts.count); !ok {
return
}
if opts.count < 0 {
setDirty(c)
c.WriteError(msgOutOfRange)
return
}
opts.withCount = true
args = args[1:]
}
if len(args) > 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.key) {
// non-existing key is fine
c.WriteNull()
return
}
if db.t(opts.key) != "list" {
c.WriteError(msgWrongType)
return
}
if opts.withCount {
var popped []string
for opts.count > 0 && len(db.listKeys[opts.key]) > 0 {
switch lr {
case left:
popped = append(popped, db.listLpop(opts.key))
case right:
popped = append(popped, db.listPop(opts.key))
}
opts.count -= 1
}
if len(popped) == 0 {
c.WriteLen(-1)
} else {
c.WriteStrings(popped)
}
return
}
var elem string
switch lr {
case left:
elem = db.listLpop(opts.key)
case right:
elem = db.listPop(opts.key)
}
c.WriteBulk(elem)
})
}
// LPUSH
func (m *Miniredis) cmdLpush(c *server.Peer, cmd string, args []string) {
m.cmdXpush(c, cmd, args, left)
}
// RPUSH
func (m *Miniredis) cmdRpush(c *server.Peer, cmd string, args []string) {
m.cmdXpush(c, cmd, args, right)
}
func (m *Miniredis) cmdXpush(c *server.Peer, cmd string, args []string, lr leftright) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, args := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if db.exists(key) && db.t(key) != "list" {
c.WriteError(msgWrongType)
return
}
var newLen int
for _, value := range args {
switch lr {
case left:
newLen = db.listLpush(key, value)
case right:
newLen = db.listPush(key, value)
}
}
c.WriteInt(newLen)
})
}
// LPUSHX
func (m *Miniredis) cmdLpushx(c *server.Peer, cmd string, args []string) {
m.cmdXpushx(c, cmd, args, left)
}
// RPUSHX
func (m *Miniredis) cmdRpushx(c *server.Peer, cmd string, args []string) {
m.cmdXpushx(c, cmd, args, right)
}
func (m *Miniredis) cmdXpushx(c *server.Peer, cmd string, args []string, lr leftright) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, args := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteInt(0)
return
}
if db.t(key) != "list" {
c.WriteError(msgWrongType)
return
}
var newLen int
for _, value := range args {
switch lr {
case left:
newLen = db.listLpush(key, value)
case right:
newLen = db.listPush(key, value)
}
}
c.WriteInt(newLen)
})
}
// LRANGE
func (m *Miniredis) cmdLrange(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
start int
end int
}{
key: args[0],
}
if ok := optInt(c, args[1], &opts.start); !ok {
return
}
if ok := optInt(c, args[2], &opts.end); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[opts.key]; ok && t != "list" {
c.WriteError(msgWrongType)
return
}
l := db.listKeys[opts.key]
if len(l) == 0 {
c.WriteLen(0)
return
}
rs, re := redisRange(len(l), opts.start, opts.end, false)
c.WriteLen(re - rs)
for _, el := range l[rs:re] {
c.WriteBulk(el)
}
})
}
// LREM
func (m *Miniredis) cmdLrem(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
key string
count int
value string
}
opts.key = args[0]
if ok := optInt(c, args[1], &opts.count); !ok {
return
}
opts.value = args[2]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.key) {
c.WriteInt(0)
return
}
if db.t(opts.key) != "list" {
c.WriteError(msgWrongType)
return
}
l := db.listKeys[opts.key]
if opts.count < 0 {
reverseSlice(l)
}
deleted := 0
newL := []string{}
toDelete := len(l)
if opts.count < 0 {
toDelete = -opts.count
}
if opts.count > 0 {
toDelete = opts.count
}
for _, el := range l {
if el == opts.value {
if toDelete > 0 {
deleted++
toDelete--
continue
}
}
newL = append(newL, el)
}
if opts.count < 0 {
reverseSlice(newL)
}
if len(newL) == 0 {
db.del(opts.key, true)
} else {
db.listKeys[opts.key] = newL
db.keyVersion[opts.key]++
}
c.WriteInt(deleted)
})
}
// LSET
func (m *Miniredis) cmdLset(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
key string
index int
value string
}
opts.key = args[0]
if ok := optInt(c, args[1], &opts.index); !ok {
return
}
opts.value = args[2]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.key) {
c.WriteError(msgKeyNotFound)
return
}
if db.t(opts.key) != "list" {
c.WriteError(msgWrongType)
return
}
l := db.listKeys[opts.key]
index := opts.index
if index < 0 {
index = len(l) + index
}
if index < 0 || index > len(l)-1 {
c.WriteError(msgOutOfRange)
return
}
l[index] = opts.value
db.keyVersion[opts.key]++
c.WriteOK()
})
}
// LTRIM
func (m *Miniredis) cmdLtrim(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
key string
start int
end int
}
opts.key = args[0]
if ok := optInt(c, args[1], &opts.start); !ok {
return
}
if ok := optInt(c, args[2], &opts.end); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[opts.key]
if !ok {
c.WriteOK()
return
}
if t != "list" {
c.WriteError(msgWrongType)
return
}
l := db.listKeys[opts.key]
rs, re := redisRange(len(l), opts.start, opts.end, false)
l = l[rs:re]
if len(l) == 0 {
db.del(opts.key, true)
} else {
db.listKeys[opts.key] = l
db.keyVersion[opts.key]++
}
c.WriteOK()
})
}
// RPOPLPUSH
func (m *Miniredis) cmdRpoplpush(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
src, dst := args[0], args[1]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(src) {
c.WriteNull()
return
}
if db.t(src) != "list" || (db.exists(dst) && db.t(dst) != "list") {
c.WriteError(msgWrongType)
return
}
elem := db.listPop(src)
db.listLpush(dst, elem)
c.WriteBulk(elem)
})
}
// BRPOPLPUSH
func (m *Miniredis) cmdBrpoplpush(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
src string
dst string
timeout int
}
opts.src = args[0]
opts.dst = args[1]
if ok := optIntErr(c, args[2], &opts.timeout, msgInvalidTimeout); !ok {
return
}
if opts.timeout < 0 {
setDirty(c)
c.WriteError(msgNegTimeout)
return
}
blocking(
m,
c,
time.Duration(opts.timeout)*time.Second,
func(c *server.Peer, ctx *connCtx) bool {
db := m.db(ctx.selectedDB)
if !db.exists(opts.src) {
return false
}
if db.t(opts.src) != "list" || (db.exists(opts.dst) && db.t(opts.dst) != "list") {
c.WriteError(msgWrongType)
return true
}
if len(db.listKeys[opts.src]) == 0 {
return false
}
elem := db.listPop(opts.src)
db.listLpush(opts.dst, elem)
c.WriteBulk(elem)
return true
},
func(c *server.Peer) {
// timeout
c.WriteLen(-1)
},
)
}
// LMOVE
func (m *Miniredis) cmdLmove(c *server.Peer, cmd string, args []string) {
if len(args) != 4 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
src string
dst string
srcDir string
dstDir string
}{
src: args[0],
dst: args[1],
srcDir: strings.ToLower(args[2]),
dstDir: strings.ToLower(args[3]),
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.src) {
c.WriteNull()
return
}
if db.t(opts.src) != "list" || (db.exists(opts.dst) && db.t(opts.dst) != "list") {
c.WriteError(msgWrongType)
return
}
var elem string
switch opts.srcDir {
case "left":
elem = db.listLpop(opts.src)
case "right":
elem = db.listPop(opts.src)
default:
c.WriteError(msgSyntaxError)
return
}
switch opts.dstDir {
case "left":
db.listLpush(opts.dst, elem)
case "right":
db.listPush(opts.dst, elem)
default:
c.WriteError(msgSyntaxError)
return
}
c.WriteBulk(elem)
})
}

Some files were not shown because too many files have changed in this diff Show More