refactor: mqtt broker mochi import packages + connack packet custom
This commit is contained in:
parent
1a1d6abcc1
commit
51acf7569f
|
|
@ -1,43 +0,0 @@
|
|||
name: build
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -race ./... && echo true
|
||||
|
||||
|
||||
coverage:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
if: success()
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.19.x
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- name: Calc coverage
|
||||
run: |
|
||||
go test -v -covermode=count -coverprofile=coverage.out ./...
|
||||
- name: Convert coverage.out to coverage.lcov
|
||||
uses: jandelgado/gcov2lcov-action@v1.0.6
|
||||
- name: Coveralls
|
||||
uses: coverallsapp/github-action@v1.1.2
|
||||
with:
|
||||
github-token: ${{ secrets.github_token }}
|
||||
path-to-lcov: coverage.lcov
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
linters:
|
||||
disable-all: false
|
||||
fix: false # Fix found issues (if it's supported by the linter).
|
||||
enable:
|
||||
# - asasalint
|
||||
# - asciicheck
|
||||
# - bidichk
|
||||
# - bodyclose
|
||||
# - containedctx
|
||||
# - contextcheck
|
||||
#- cyclop
|
||||
# - deadcode
|
||||
- decorder
|
||||
# - depguard
|
||||
# - dogsled
|
||||
# - dupl
|
||||
- durationcheck
|
||||
# - errchkjson
|
||||
# - errname
|
||||
- errorlint
|
||||
# - execinquery
|
||||
# - exhaustive
|
||||
# - exhaustruct
|
||||
# - exportloopref
|
||||
#- forcetypeassert
|
||||
#- forbidigo
|
||||
#- funlen
|
||||
#- gci
|
||||
# - gochecknoglobals
|
||||
# - gochecknoinits
|
||||
# - gocognit
|
||||
# - goconst
|
||||
# - gocritic
|
||||
- gocyclo
|
||||
- godot
|
||||
# - godox
|
||||
# - goerr113
|
||||
# - gofmt
|
||||
# - gofumpt
|
||||
# - goheader
|
||||
- goimports
|
||||
# - golint
|
||||
# - gomnd
|
||||
# - gomoddirectives
|
||||
# - gomodguard
|
||||
# - goprintffuncname
|
||||
- gosec
|
||||
- gosimple
|
||||
- govet
|
||||
# - grouper
|
||||
# - ifshort
|
||||
- importas
|
||||
- ineffassign
|
||||
# - interfacebloat
|
||||
# - interfacer
|
||||
# - ireturn
|
||||
# - lll
|
||||
# - maintidx
|
||||
# - makezero
|
||||
- maligned
|
||||
- misspell
|
||||
# - nakedret
|
||||
# - nestif
|
||||
# - nilerr
|
||||
# - nilnil
|
||||
# - nlreturn
|
||||
# - noctx
|
||||
# - nolintlint
|
||||
# - nonamedreturns
|
||||
# - nosnakecase
|
||||
# - nosprintfhostport
|
||||
# - paralleltest
|
||||
# - prealloc
|
||||
# - predeclared
|
||||
# - promlinter
|
||||
- reassign
|
||||
# - revive
|
||||
# - rowserrcheck
|
||||
# - scopelint
|
||||
# - sqlclosecheck
|
||||
# - staticcheck
|
||||
# - structcheck
|
||||
# - stylecheck
|
||||
# - tagliatelle
|
||||
# - tenv
|
||||
# - testpackage
|
||||
# - thelper
|
||||
- tparallel
|
||||
# - typecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- usestdlibvars
|
||||
# - varcheck
|
||||
# - varnamelen
|
||||
- wastedassign
|
||||
- whitespace
|
||||
# - wrapcheck
|
||||
# - wsl
|
||||
disable:
|
||||
- errcheck
|
||||
|
||||
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
FROM golang:1.19.0-alpine3.15 AS builder
|
||||
|
||||
RUN apk update
|
||||
RUN apk add git
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod ./
|
||||
COPY go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . ./
|
||||
|
||||
RUN go build -o /app/mochi ./cmd
|
||||
|
||||
|
||||
FROM alpine
|
||||
|
||||
WORKDIR /
|
||||
COPY --from=builder /app/mochi .
|
||||
|
||||
# tcp
|
||||
EXPOSE 1883
|
||||
|
||||
# websockets
|
||||
EXPOSE 1882
|
||||
|
||||
# dashboard
|
||||
EXPOSE 8080
|
||||
|
||||
ENTRYPOINT [ "/mochi" ]
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2019, 2022 Jonathan Blake (mochi-co)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
This broker is an implementation of mochi. I've to forked it to customize CONNACK packet userProperty, although mochi lib might have a better approach to do it.
|
||||
|
||||
To run this project you might have Go compiler in your machine, and inside cmd folder there is a run.sh script, which runs the project with the right arguments; also inside the same folder is the auth.json file, that carries configs of RBAC.
|
||||
|
||||

|
||||
|
|
@ -1,568 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds
|
||||
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
|
||||
)
|
||||
|
||||
// ReadFn is the function signature for the function used for reading and processing new packets.
|
||||
type ReadFn func(*Client, packets.Packet) error
|
||||
|
||||
// Clients contains a map of the clients known by the broker.
|
||||
type Clients struct {
|
||||
internal map[string]*Client // clients known by the broker, keyed on client id.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClients returns an instance of Clients.
|
||||
func NewClients() *Clients {
|
||||
return &Clients{
|
||||
internal: make(map[string]*Client),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new client to the clients map, keyed on client id.
|
||||
func (cl *Clients) Add(val *Client) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
cl.internal[val.ID] = val
|
||||
}
|
||||
|
||||
// GetAll returns all the clients.
|
||||
func (cl *Clients) GetAll() map[string]*Client {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
m := map[string]*Client{}
|
||||
for k, v := range cl.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns the value of a client if it exists.
|
||||
func (cl *Clients) Get(id string) (*Client, bool) {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
val, ok := cl.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the length of the clients map.
|
||||
func (cl *Clients) Len() int {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
val := len(cl.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes a client from the internal map.
|
||||
func (cl *Clients) Delete(id string) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
delete(cl.internal, id)
|
||||
}
|
||||
|
||||
// GetByListener returns clients matching a listener id.
|
||||
func (cl *Clients) GetByListener(id string) []*Client {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
for _, client := range cl.internal {
|
||||
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
}
|
||||
return clients
|
||||
}
|
||||
|
||||
// Client contains information about a client known by the broker.
|
||||
type Client struct {
|
||||
Properties ClientProperties // client properties
|
||||
State ClientState // the operational state of the client.
|
||||
Net ClientConnection // network connection state of the clinet
|
||||
ID string // the client id.
|
||||
ops *ops // ops provides a reference to server ops.
|
||||
sync.RWMutex // mutex
|
||||
}
|
||||
|
||||
// ClientConnection contains the connection transport and metadata for the client.
|
||||
type ClientConnection struct {
|
||||
Conn net.Conn // the net.Conn used to establish the connection
|
||||
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
|
||||
Remote string // the remote address of the client
|
||||
Listener string // listener id of the client
|
||||
Inline bool // client is an inline programmetic client
|
||||
}
|
||||
|
||||
// ClientProperties contains the properties which define the client behaviour.
|
||||
type ClientProperties struct {
|
||||
Props packets.Properties
|
||||
Will Will
|
||||
Username []byte
|
||||
ProtocolVersion byte
|
||||
Clean bool
|
||||
}
|
||||
|
||||
// Will contains the last will and testament details for a client connection.
|
||||
type Will struct {
|
||||
Payload []byte // -
|
||||
User []packets.UserProperty // -
|
||||
TopicName string // -
|
||||
Flag uint32 // 0,1
|
||||
WillDelayInterval uint32 // -
|
||||
Qos byte // -
|
||||
Retain bool // -
|
||||
}
|
||||
|
||||
// State tracks the state of the client.
|
||||
type ClientState struct {
|
||||
TopicAliases TopicAliases // a map of topic aliases
|
||||
stopCause atomic.Value // reason for stopping
|
||||
Inflight *Inflight // a map of in-flight qos messages
|
||||
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
|
||||
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
|
||||
outbound chan *packets.Packet // queue for pending outbound packets
|
||||
endOnce sync.Once // only end once
|
||||
isTakenOver uint32 // used to identify orphaned clients
|
||||
packetID uint32 // the current highest packetID
|
||||
done uint32 // atomic counter which indicates that the client has closed
|
||||
outboundQty int32 // number of messages currently in the outbound queue
|
||||
keepalive uint16 // the number of seconds the connection can wait
|
||||
}
|
||||
|
||||
// newClient returns a new instance of Client. This is almost exclusively used by Server
|
||||
// for creating new clients, but it lives here because it's not dependent.
|
||||
func newClient(c net.Conn, o *ops) *Client {
|
||||
cl := &Client{
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
|
||||
keepalive: defaultKeepalive,
|
||||
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
ops: o,
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
cl.Net = ClientConnection{
|
||||
Conn: c,
|
||||
bconn: bufio.NewReadWriter(
|
||||
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
|
||||
bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize),
|
||||
),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
|
||||
func (cl *Client) WriteLoop() {
|
||||
for pk := range cl.State.outbound {
|
||||
if err := cl.WritePacket(*pk); err != nil {
|
||||
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
|
||||
}
|
||||
atomic.AddInt32(&cl.State.outboundQty, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseConnect parses the connect parameters and properties for a client.
|
||||
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Net.Listener = lid
|
||||
|
||||
cl.Properties.ProtocolVersion = pk.ProtocolVersion
|
||||
cl.Properties.Username = pk.Connect.Username
|
||||
cl.Properties.Clean = pk.Connect.Clean
|
||||
cl.Properties.Props = pk.Properties.Copy(false)
|
||||
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
|
||||
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
|
||||
|
||||
cl.ID = pk.Connect.ClientIdentifier
|
||||
if cl.ID == "" {
|
||||
cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
|
||||
cl.Properties.Props.AssignedClientID = cl.ID
|
||||
}
|
||||
|
||||
cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
|
||||
if pk.Connect.Keepalive > 0 {
|
||||
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
|
||||
}
|
||||
|
||||
if pk.Connect.WillFlag {
|
||||
cl.Properties.Will = Will{
|
||||
Qos: pk.Connect.WillQos,
|
||||
Retain: pk.Connect.WillRetain,
|
||||
Payload: pk.Connect.WillPayload,
|
||||
TopicName: pk.Connect.WillTopic,
|
||||
WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
|
||||
User: pk.Connect.WillProperties.User,
|
||||
}
|
||||
if pk.Properties.SessionExpiryIntervalFlag &&
|
||||
pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
|
||||
cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
|
||||
}
|
||||
if pk.Connect.WillFlag {
|
||||
cl.Properties.Will.Flag = 1 // atomic for checking
|
||||
}
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
}
|
||||
|
||||
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||
func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
|
||||
}
|
||||
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
}
|
||||
}
|
||||
|
||||
// NextPacketID returns the next available (unused) packet id for the client.
|
||||
// If no unused packet ids are available, an error is returned and the client
|
||||
// should be disconnected.
|
||||
func (cl *Client) NextPacketID() (i uint32, err error) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
|
||||
i = atomic.LoadUint32(&cl.State.packetID)
|
||||
started := i
|
||||
overflowed := false
|
||||
for {
|
||||
if overflowed && i == started {
|
||||
return 0, packets.ErrQuotaExceeded
|
||||
}
|
||||
|
||||
if i >= cl.ops.options.Capabilities.maximumPacketID {
|
||||
overflowed = true
|
||||
i = 0
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
|
||||
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
|
||||
atomic.StoreUint32(&cl.State.packetID, i)
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
|
||||
func (cl *Client) ResendInflightMessages(force bool) error {
|
||||
if cl.State.Inflight.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if tk.FixedHeader.Type == packets.Publish {
|
||||
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
|
||||
}
|
||||
|
||||
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
|
||||
err := cl.WritePacket(tk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosComplete(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
|
||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
|
||||
deleted := []uint16{}
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
deleted = append(deleted, tk.PacketID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return deleted
|
||||
}
|
||||
|
||||
// Read reads incoming packets from the connected client and transforms them into
|
||||
// packets to be handled by the packetHandler.
|
||||
func (cl *Client) Read(packetHandler ReadFn) error {
|
||||
var err error
|
||||
|
||||
for {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
fh := new(packets.FixedHeader)
|
||||
err = cl.ReadFixedHeader(fh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = packetHandler(cl, pk) // Process inbound packet.
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop(err error) {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
cl.State.endOnce.Do(func() {
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.Close() // omit close error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cl.State.stopCause.Store(err)
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&cl.State.done, 1)
|
||||
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
|
||||
})
|
||||
}
|
||||
|
||||
// StopCause returns the reason the client connection was stopped, if any.
|
||||
func (cl *Client) StopCause() error {
|
||||
if cl.State.stopCause.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
return cl.State.stopCause.Load().(error)
|
||||
}
|
||||
|
||||
// Closed returns true if client connection is closed.
|
||||
func (cl *Client) Closed() bool {
|
||||
return atomic.LoadUint32(&cl.State.done) == 1
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
if cl.Net.bconn == nil {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
b, err := cl.Net.bconn.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fh.Decode(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var bu int
|
||||
fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadPacket reads the remaining buffer into an MQTT packet.
|
||||
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
|
||||
atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
|
||||
|
||||
pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
|
||||
pk.FixedHeader = *fh
|
||||
p := make([]byte, pk.FixedHeader.Remaining)
|
||||
n, err := io.ReadFull(cl.Net.bconn, p)
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))
|
||||
|
||||
// Decode the remaining packet values using a fresh copy of the bytes,
|
||||
// otherwise the next packet will change the data of this one.
|
||||
px := append([]byte{}, p[:]...)
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
err = pk.ConnectDecode(px)
|
||||
case packets.Disconnect:
|
||||
err = pk.DisconnectDecode(px)
|
||||
case packets.Connack:
|
||||
err = pk.ConnackDecode(px)
|
||||
case packets.Publish:
|
||||
err = pk.PublishDecode(px)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackDecode(px)
|
||||
case packets.Pubrec:
|
||||
err = pk.PubrecDecode(px)
|
||||
case packets.Pubrel:
|
||||
err = pk.PubrelDecode(px)
|
||||
case packets.Pubcomp:
|
||||
err = pk.PubcompDecode(px)
|
||||
case packets.Subscribe:
|
||||
err = pk.SubscribeDecode(px)
|
||||
case packets.Suback:
|
||||
err = pk.SubackDecode(px)
|
||||
case packets.Unsubscribe:
|
||||
err = pk.UnsubscribeDecode(px)
|
||||
case packets.Unsuback:
|
||||
err = pk.UnsubackDecode(px)
|
||||
case packets.Pingreq:
|
||||
case packets.Pingresp:
|
||||
case packets.Auth:
|
||||
err = pk.AuthDecode(px)
|
||||
default:
|
||||
err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
|
||||
return
|
||||
}
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if cl.Net.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pk.Expiry > 0 {
|
||||
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
|
||||
}
|
||||
|
||||
pk.ProtocolVersion = cl.Properties.ProtocolVersion
|
||||
if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
|
||||
pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
|
||||
}
|
||||
|
||||
if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
|
||||
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
|
||||
}
|
||||
|
||||
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
|
||||
|
||||
var err error
|
||||
buf := new(bytes.Buffer)
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
err = pk.ConnectEncode(buf)
|
||||
case packets.Connack:
|
||||
err = pk.ConnackEncode(buf)
|
||||
case packets.Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
case packets.Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
case packets.Pubrec:
|
||||
err = pk.PubrecEncode(buf)
|
||||
case packets.Pubrel:
|
||||
err = pk.PubrelEncode(buf)
|
||||
case packets.Pubcomp:
|
||||
err = pk.PubcompEncode(buf)
|
||||
case packets.Subscribe:
|
||||
err = pk.SubscribeEncode(buf)
|
||||
case packets.Suback:
|
||||
err = pk.SubackEncode(buf)
|
||||
case packets.Unsubscribe:
|
||||
err = pk.UnsubscribeEncode(buf)
|
||||
case packets.Unsuback:
|
||||
err = pk.UnsubackEncode(buf)
|
||||
case packets.Pingreq:
|
||||
err = pk.PingreqEncode(buf)
|
||||
case packets.Pingresp:
|
||||
err = pk.PingrespEncode(buf)
|
||||
case packets.Disconnect:
|
||||
err = pk.DisconnectEncode(buf)
|
||||
case packets.Auth:
|
||||
err = pk.AuthEncode(buf)
|
||||
default:
|
||||
err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
|
||||
}
|
||||
|
||||
nb := net.Buffers{buf.Bytes()}
|
||||
n, err := nb.WriteTo(cl.Net.Conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesSent, n)
|
||||
atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
|
||||
if pk.FixedHeader.Type == packets.Publish {
|
||||
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
|
||||
}
|
||||
|
||||
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -1,745 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const pkInfo = "packet type %v, %s"
|
||||
|
||||
var errClientStop = errors.New("test stop")
|
||||
|
||||
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
r, w = net.Pipe()
|
||||
|
||||
cl = newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: &logger,
|
||||
options: &Options{
|
||||
Capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
TopicAliasMaximum: 10000,
|
||||
MaximumClientWritesPending: 3,
|
||||
maximumPacketID: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
cl.ID = "mochi"
|
||||
cl.State.Inflight.maximumSendQuota = 5
|
||||
cl.State.Inflight.sendQuota = 5
|
||||
cl.State.Inflight.maximumReceiveQuota = 10
|
||||
cl.State.Inflight.receiveQuota = 10
|
||||
cl.Properties.Props.TopicAliasMaximum = 0
|
||||
cl.Properties.Props.RequestResponseInfo = 0x1
|
||||
|
||||
go cl.WriteLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestNewInflights(t *testing.T) {
|
||||
require.NotNil(t, NewInflights().internal)
|
||||
}
|
||||
|
||||
func TestNewClients(t *testing.T) {
|
||||
cl := NewClients()
|
||||
require.NotNil(t, cl.internal)
|
||||
}
|
||||
|
||||
func TestClientsAdd(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
}
|
||||
|
||||
func TestClientsGet(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
client, ok := cl.Get("t1")
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, "t1", client.ID)
|
||||
}
|
||||
|
||||
func TestClientsGetAll(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
cl.Add(&Client{ID: "t3"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
require.Contains(t, cl.internal, "t3")
|
||||
|
||||
clients := cl.GetAll()
|
||||
require.Len(t, clients, 3)
|
||||
}
|
||||
|
||||
func TestClientsLen(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
require.Equal(t, 2, cl.Len())
|
||||
}
|
||||
|
||||
func TestClientsDelete(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
|
||||
cl.Delete("t1")
|
||||
_, ok := cl.Get("t1")
|
||||
require.Equal(t, false, ok)
|
||||
require.Nil(t, cl.internal["t1"])
|
||||
}
|
||||
|
||||
func TestClientsGetByListener(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
clients := cl.GetByListener("tcp1")
|
||||
require.NotEmpty(t, clients)
|
||||
require.Equal(t, 1, len(clients))
|
||||
require.Equal(t, "tcp1", clients[0].Net.Listener)
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.Conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.NotNil(t, cl.ops)
|
||||
require.NotNil(t, cl.ops.options.Capabilities)
|
||||
require.False(t, cl.Net.Inline)
|
||||
}
|
||||
|
||||
func TestClientParseConnect(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Clean: true,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "mochi",
|
||||
WillFlag: true,
|
||||
WillTopic: "lwt",
|
||||
WillPayload: []byte("lol gg"),
|
||||
WillQos: 1,
|
||||
WillRetain: false,
|
||||
},
|
||||
Properties: packets.Properties{
|
||||
ReceiveMaximum: uint16(5),
|
||||
},
|
||||
}
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
|
||||
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
|
||||
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
|
||||
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
|
||||
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
|
||||
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
|
||||
}
|
||||
|
||||
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Clean: true,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "mochi",
|
||||
WillFlag: true,
|
||||
WillProperties: packets.Properties{
|
||||
WillDelayInterval: 200,
|
||||
},
|
||||
},
|
||||
Properties: packets.Properties{
|
||||
SessionExpiryInterval: 100,
|
||||
SessionExpiryIntervalFlag: true,
|
||||
},
|
||||
}
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval)
|
||||
}
|
||||
|
||||
func TestClientParseConnectNoID(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ParseConnect("tcp1", packets.Packet{})
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientNextPacketID(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
// skip over 2
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), i)
|
||||
|
||||
// Skip over overflow
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 65535})
|
||||
atomic.StoreUint32(&cl.State.packetID, 65534)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDExhausted(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
|
||||
}
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
require.Equal(t, uint32(0), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
|
||||
}
|
||||
|
||||
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
|
||||
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
|
||||
|
||||
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
|
||||
_, err = cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
}
|
||||
|
||||
func TestClientClearInflights(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
|
||||
deleted := cl.ClearInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
|
||||
require.Equal(t, 2, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessages(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
go func() {
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, cl.State.Inflight.Len())
|
||||
require.Equal(t, pk1.RawBytes, buf)
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
|
||||
cl, r, _ := newTestClient()
|
||||
r.Close()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientRefreshDeadline(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.refreshDeadline(10)
|
||||
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeader(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect << 4, 0x00})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived))
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
cl.ops.options.Capabilities.MaximumPacketSize = 2
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadOK(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
packets.Publish << 4, 18, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload,
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
var pks []packets.Packet
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
pks = append(pks, pk)
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
err := <-o
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.Equal(t, 2, len(pks))
|
||||
require.Equal(t, []packets.Packet{
|
||||
{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 18,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 11,
|
||||
},
|
||||
TopicName: "d/e/f",
|
||||
Payload: []byte("yeah"),
|
||||
},
|
||||
}, pks)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived))
|
||||
}
|
||||
|
||||
func TestClientReadDone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
cl.State.done = 1
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
require.NoError(t, <-o)
|
||||
}
|
||||
|
||||
func TestClientStop(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(nil)
|
||||
require.Equal(t, nil, cl.State.stopCause.Load())
|
||||
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
|
||||
require.Equal(t, uint32(1), cl.State.done)
|
||||
require.Equal(t, nil, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestClientClosed(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
require.False(t, cl.Closed())
|
||||
cl.Stop(nil)
|
||||
require.True(t, cl.Closed())
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
cl.Net.bconn = nil
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrConnectionClosed, err)
|
||||
}
|
||||
|
||||
func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
err := cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
return errors.New("test")
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadReadPacketOK(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pk)
|
||||
|
||||
require.Equal(t, packets.Packet{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 11,
|
||||
},
|
||||
TopicName: "d/e/f",
|
||||
Payload: []byte("yeah"),
|
||||
}, pk)
|
||||
}
|
||||
|
||||
func TestClientReadPacket(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
for _, tx := range pkTable {
|
||||
tt := tx // avoid data race
|
||||
t.Run(tt.Desc, func(t *testing.T) {
|
||||
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
|
||||
go func() {
|
||||
r.Write(tt.RawBytes)
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.Packet.ProtocolVersion == 5 {
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
} else {
|
||||
cl.Properties.ProtocolVersion = 0
|
||||
}
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
if tt.Packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadPacketInvalidTypeError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Conn.Close()
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid packet type")
|
||||
}
|
||||
|
||||
func TestClientWritePacket(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
|
||||
|
||||
o := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
o <- buf
|
||||
}()
|
||||
|
||||
err := cl.WritePacket(*tt.Packet)
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
cl.Net.Conn.Close()
|
||||
|
||||
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
cl.Stop(errClientStop)
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
|
||||
// The stop cause is either the test error, EOF, or a
|
||||
// closed pipe, depending on which goroutine runs first.
|
||||
err = cl.StopCause()
|
||||
require.True(t,
|
||||
errors.Is(err, errClientStop) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, io.ErrClosedPipe))
|
||||
|
||||
require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent))
|
||||
if tt.Packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteClientOversizePacket(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.MaximumPacketSize = 2
|
||||
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
|
||||
err := cl.WritePacket(pk)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, packets.ErrPacketTooLarge, err)
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadingError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
Type: 0,
|
||||
Remaining: 11,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
Remaining: 1,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(errClientStop)
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, ErrConnectionClosed, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Conn.Close()
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketInvalidPacket(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.WritePacket(packets.Packet{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
var (
|
||||
pkTable = []packets.TPacketCase{
|
||||
packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311),
|
||||
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5),
|
||||
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession),
|
||||
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic),
|
||||
packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5),
|
||||
packets.TPacketData[packets.Puback].Get(packets.TPuback),
|
||||
packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
|
||||
packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
|
||||
packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
|
||||
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe),
|
||||
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5),
|
||||
packets.TPacketData[packets.Suback].Get(packets.TSuback),
|
||||
packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5),
|
||||
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe),
|
||||
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5),
|
||||
packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback),
|
||||
packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5),
|
||||
packets.TPacketData[packets.Pingreq].Get(packets.TPingreq),
|
||||
packets.TPacketData[packets.Pingresp].Get(packets.TPingresp),
|
||||
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect),
|
||||
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5),
|
||||
packets.TPacketData[packets.Auth].Get(packets.TAuth),
|
||||
}
|
||||
)
|
||||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/tls"
|
||||
"flag"
|
||||
rv8 "github.com/go-redis/redis/v8"
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/redis"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/rs/zerolog"
|
||||
|
|
@ -15,42 +16,12 @@ import (
|
|||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
var (
|
||||
// testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
//MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
//VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
|
||||
//BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
|
||||
//VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
|
||||
//DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
|
||||
//AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
|
||||
//OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
|
||||
//MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
|
||||
//gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
|
||||
//qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
|
||||
//zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
|
||||
//-----END CERTIFICATE-----`)
|
||||
//
|
||||
// testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
//MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
|
||||
//FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
|
||||
//rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
|
||||
//AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
|
||||
//UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
|
||||
//n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
|
||||
//mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
|
||||
//INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
|
||||
//AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
|
||||
///F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
|
||||
//WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
//w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
//OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
//-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
//TODO: create custom mqtt server options
|
||||
server = mqtt.New(&mqtt.Options{
|
||||
//Capabilities: &mqtt.Capabilities{
|
||||
// ServerKeepAlive: 10000,
|
||||
|
|
@ -233,6 +204,7 @@ func (h *MyHook) Provides(b byte) bool {
|
|||
mqtt.OnSubscribed,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnPacketEncode,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
|
|
@ -284,3 +256,15 @@ func (h *MyHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []
|
|||
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MyHook) OnPacketEncode(cl *mqtt.Client, pk packets.Packet) packets.Packet {
|
||||
var clUser string
|
||||
if len(cl.Properties.Props.User) > 0 {
|
||||
clUser = cl.Properties.Props.User[0].Val
|
||||
}
|
||||
if pk.FixedHeader.Type == packets.Connack {
|
||||
pk.Properties.User = []packets.UserProperty{{Key: "subscribe-topic", Val: "oktopus/v1/agent/" + clUser}}
|
||||
}
|
||||
|
||||
return pk
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,83 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
authRules := &auth.Ledger{
|
||||
Auth: auth.AuthRules{ // Auth disallows all by default
|
||||
{Username: "peach", Password: "password1", Allow: true},
|
||||
{Username: "melon", Password: "password2", Allow: true},
|
||||
{Remote: "127.0.0.1:*", Allow: true},
|
||||
{Remote: "localhost:*", Allow: true},
|
||||
},
|
||||
ACL: auth.ACLRules{ // ACL allows all by default
|
||||
{Remote: "127.0.0.1:*"}, // local superuser allow all
|
||||
{
|
||||
// user melon can read and write to their own topic
|
||||
Username: "melon", Filters: auth.Filters{
|
||||
"melon/#": auth.ReadWrite,
|
||||
"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
|
||||
},
|
||||
},
|
||||
{
|
||||
// Otherwise, no clients have publishing permissions
|
||||
Filters: auth.Filters{
|
||||
"#": auth.ReadOnly,
|
||||
"updates/#": auth.Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// you may also find this useful...
|
||||
// d, _ := authRules.ToYAML()
|
||||
// d, _ := authRules.ToJSON()
|
||||
// fmt.Println(string(d))
|
||||
|
||||
server := mqtt.New(nil)
|
||||
err := server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Ledger: authRules,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
{
|
||||
"auth": [
|
||||
{
|
||||
"username": "leandro",
|
||||
"password": "leandro",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"username": "steve",
|
||||
"password": "steve",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"username": "root",
|
||||
"password": "root",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"remote": "*",
|
||||
"allow": false
|
||||
},
|
||||
{
|
||||
"remote": "*",
|
||||
"allow": false
|
||||
}
|
||||
],
|
||||
"acl": [
|
||||
{
|
||||
"remote": "*"
|
||||
},
|
||||
{
|
||||
"username": "leandro",
|
||||
"filters": {
|
||||
"oktopus/+/agent/+": 1,
|
||||
"oktopus/+/controller/+": 2
|
||||
}
|
||||
},
|
||||
{
|
||||
"username": "steve",
|
||||
"filters": {
|
||||
"oktopus/+/agent/+": 1,
|
||||
"oktopus/+/controller/+": 2
|
||||
}
|
||||
},
|
||||
{
|
||||
"username": "root",
|
||||
"filters": {
|
||||
"#": 3
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
auth:
|
||||
- username: peach
|
||||
password: password1
|
||||
allow: true
|
||||
- username: melon
|
||||
password: password2
|
||||
allow: true
|
||||
# - remote: 127.0.0.1:*
|
||||
# allow: true
|
||||
# - remote: localhost:*
|
||||
# allow: true
|
||||
acl:
|
||||
# 0 = deny, 1 = read only, 2 = write only, 3 = read and write
|
||||
- remote: 127.0.0.1:*
|
||||
- username: melon
|
||||
filters:
|
||||
melon/#: 3
|
||||
updates/#: 2
|
||||
- filters:
|
||||
'#': 1
|
||||
updates/#: 0
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// You can also run from top-level server.go folder:
|
||||
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.yaml
|
||||
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.json
|
||||
path := flag.String("path", "auth.yaml", "path to data auth file")
|
||||
flag.Parse()
|
||||
|
||||
// Get ledger from yaml file
|
||||
data, err := os.ReadFile(*path)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
server := mqtt.New(nil)
|
||||
err = server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Data: data, // build ledger from byte slice, yaml or json
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
|
||||
flag.Parse()
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/debug"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
err := server.AddHook(new(auth.AllowHook), nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddHook(new(debug.Hook), &debug.Options{
|
||||
// ShowPacketData: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddHook(new(ExampleHook), map[string]any{})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start the server
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Demonstration of directly publishing messages to a topic via the
|
||||
// `server.Publish` method. Subscribe to `direct/publish` using your
|
||||
// MQTT client to see the messages.
|
||||
go func() {
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
for range time.Tick(time.Second * 1) {
|
||||
err := server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "direct/publish",
|
||||
Payload: []byte("injected scheduled message"),
|
||||
})
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.InjectPacket")
|
||||
}
|
||||
server.Log.Info().Msgf("main.go injected packet to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
// There is also a shorthand convenience function, Publish, for easily sending
|
||||
// publish packets if you are not concerned with creating your own packets.
|
||||
go func() {
|
||||
for range time.Tick(time.Second * 5) {
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.Publish")
|
||||
}
|
||||
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
||||
type ExampleHook struct {
|
||||
mqtt.HookBase
|
||||
}
|
||||
|
||||
func (h *ExampleHook) ID() string {
|
||||
return "events-example"
|
||||
}
|
||||
|
||||
func (h *ExampleHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnect,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnPublished,
|
||||
mqtt.OnPublish,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
func (h *ExampleHook) Init(config any) error {
|
||||
h.Log.Info().Msg("initialised")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
|
||||
h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected")
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes)
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed")
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client")
|
||||
|
||||
pkx := pk
|
||||
if string(pk.Payload) == "hello" {
|
||||
pkx.Payload = []byte("hello world")
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client")
|
||||
}
|
||||
|
||||
return pkx, nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client")
|
||||
}
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
server.Options.Capabilities.ServerKeepAlive = 60
|
||||
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
||||
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
|
||||
|
||||
_ = server.AddHook(new(pahoAuthHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
||||
type pahoAuthHook struct {
|
||||
mqtt.HookBase
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) ID() string {
|
||||
return "allow-all-auth"
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
return topic != "test/nosubscribe"
|
||||
}
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/badger"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
badgerPath := ".badger"
|
||||
defer os.RemoveAll(badgerPath) // remove the example badger files at the end
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
err := server.AddHook(new(badger.Hook), &badger.Options{
|
||||
Path: badgerPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
|
||||
}
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/bolt"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
err := server.AddHook(new(bolt.Hook), &bolt.Options{
|
||||
Path: "bolt.db",
|
||||
Options: &bbolt.Options{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/redis"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
rv8 "github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
err := server.AddHook(new(redis.Hook), &redis.Options{
|
||||
Options: &rv8.Options{
|
||||
Addr: "localhost:6379", // default redis address
|
||||
Password: "", // your password
|
||||
DB: 0, // your redis db
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
|
||||
}
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// An example of configuring various server options...
|
||||
options := &mqtt.Options{
|
||||
// InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
|
||||
}
|
||||
|
||||
server := mqtt.New(options)
|
||||
|
||||
// For security reasons, the default implementation disallows all connections.
|
||||
// If you want to allow all connections, you must specifically allow it.
|
||||
err := server.AddHook(new(auth.AllowHook), nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
var (
|
||||
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
|
||||
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
|
||||
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
|
||||
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
|
||||
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
|
||||
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
|
||||
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
|
||||
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
|
||||
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
|
||||
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
|
||||
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
|
||||
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
|
||||
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
|
||||
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
|
||||
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
|
||||
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
|
||||
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
|
||||
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
|
||||
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
|
||||
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
// Optionally, if you want clients to authenticate only with certs issued by your CA,
|
||||
// you might want to use something like this:
|
||||
// certPool := x509.NewCertPool()
|
||||
// _ = certPool.AppendCertsFromPEM(caCertPem)
|
||||
// tlsConfig := &tls.Config{
|
||||
// ClientCAs: certPool,
|
||||
// ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// }
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", &listeners.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ws := listeners.NewWebsocket("ws1", ":1882", &listeners.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
err = server.AddListener(ws)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
}, nil)
|
||||
err = server.AddListener(stats)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
ws := listeners.NewWebsocket("ws1", ":1882", nil)
|
||||
err := server.AddListener(ws)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
|
@ -1,40 +1,20 @@
|
|||
module github.com/mochi-co/mqtt/v2
|
||||
module broker
|
||||
|
||||
go 1.19
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.23.0
|
||||
github.com/asdine/storm v2.1.2+incompatible
|
||||
github.com/asdine/storm/v3 v3.2.1
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.3.5
|
||||
github.com/rs/xid v1.4.0
|
||||
github.com/rs/zerolog v1.28.0
|
||||
github.com/stretchr/testify v1.7.1
|
||||
github.com/timshannon/badgerhold v1.0.0
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
github.com/mochi-co/mqtt/v2 v2.2.16
|
||||
github.com/rs/zerolog v1.29.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgraph-io/badger v1.6.0 // indirect
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.0 // indirect
|
||||
github.com/golang/protobuf v1.5.0 // indirect
|
||||
github.com/golang/snappy v0.0.3 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.12 // indirect
|
||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
|
||||
golang.org/x/net v0.7.0 // indirect
|
||||
github.com/rs/xid v1.4.0 // indirect
|
||||
golang.org/x/sys v0.5.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/protobuf v1.28.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,143 +1,44 @@
|
|||
github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
|
||||
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M=
|
||||
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM=
|
||||
github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
|
||||
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM=
|
||||
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM=
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||
github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw=
|
||||
github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88=
|
||||
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
|
||||
github.com/asdine/storm v2.1.2+incompatible h1:dczuIkyqwY2LrtXPz8ixMrU/OFgZp71kbKTHGrXYt/Q=
|
||||
github.com/asdine/storm v2.1.2+incompatible/go.mod h1:RarYDc9hq1UPLImuiXK3BIWPJLdIygvV3PsInK0FbVQ=
|
||||
github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac=
|
||||
github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0=
|
||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
||||
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
|
||||
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
|
||||
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgraph-io/badger v1.6.0 h1:DshxFxZWXUcO0xX476VJC07Xsr6ZCBVRHKZ93Oh7Evo=
|
||||
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
|
||||
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
|
||||
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/mochi-co/mqtt/v2 v2.2.16 h1:CBqbxFhExzASNjj4BjSei0hYY1F5N5IeDqNVhjN+tp8=
|
||||
github.com/mochi-co/mqtt/v2 v2.2.16/go.mod h1:MDMTThFgWj/LjJ6wc51bP5l4xnJG/ahpc9tR9vZVf8Q=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
|
||||
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
|
||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
|
||||
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
|
||||
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
|
||||
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
|
||||
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
|
||||
github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
|
||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/timshannon/badgerhold v1.0.0 h1:LtqnDRVP7294FWRiZCIfQa6Tt0bGmlzbO8c364QC2Y8=
|
||||
github.com/timshannon/badgerhold v1.0.0/go.mod h1:Vv2Jj0PAfzqViEpGvJzLP8PY07x1iXLgKRuLY7bqPOE=
|
||||
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
|
||||
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
|
||||
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw=
|
||||
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
|
||||
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
|
||||
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
|
||||
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
|||
|
|
@ -1,794 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co, thedevop
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
SetOptions byte = iota
|
||||
OnSysInfoTick
|
||||
OnStarted
|
||||
OnStopped
|
||||
OnConnectAuthenticate
|
||||
OnACLCheck
|
||||
OnConnect
|
||||
OnSessionEstablished
|
||||
OnDisconnect
|
||||
OnAuthPacket
|
||||
OnPacketRead
|
||||
OnPacketEncode
|
||||
OnPacketSent
|
||||
OnPacketProcessed
|
||||
OnSubscribe
|
||||
OnSubscribed
|
||||
OnSelectSubscribers
|
||||
OnUnsubscribe
|
||||
OnUnsubscribed
|
||||
OnPublish
|
||||
OnPublished
|
||||
OnPublishDropped
|
||||
OnRetainMessage
|
||||
OnQosPublish
|
||||
OnQosComplete
|
||||
OnQosDropped
|
||||
OnWill
|
||||
OnWillSent
|
||||
OnClientExpired
|
||||
OnRetainedExpired
|
||||
StoredClients
|
||||
StoredSubscriptions
|
||||
StoredInflightMessages
|
||||
StoredRetainedMessages
|
||||
StoredSysInfo
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidConfigType indicates a different Type of config value was expected to what was received.
|
||||
ErrInvalidConfigType = errors.New("invalid config type provided")
|
||||
)
|
||||
|
||||
// Hook provides an interface of handlers for different events which occur
|
||||
// during the lifecycle of the broker.
|
||||
type Hook interface {
|
||||
ID() string
|
||||
Provides(b byte) bool
|
||||
Init(config any) error
|
||||
Stop() error
|
||||
SetOpts(l *zerolog.Logger, o *HookOptions)
|
||||
OnStarted()
|
||||
OnStopped()
|
||||
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
|
||||
OnACLCheck(cl *Client, topic string, write bool) bool
|
||||
OnSysInfoTick(*system.Info)
|
||||
OnConnect(cl *Client, pk packets.Packet)
|
||||
OnSessionEstablished(cl *Client, pk packets.Packet)
|
||||
OnDisconnect(cl *Client, err error, expire bool)
|
||||
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation
|
||||
OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client
|
||||
OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client
|
||||
OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled)
|
||||
OnSubscribe(cl *Client, pk packets.Packet) packets.Packet
|
||||
OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
|
||||
OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers
|
||||
OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet
|
||||
OnUnsubscribed(cl *Client, pk packets.Packet)
|
||||
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
OnPublished(cl *Client, pk packets.Packet)
|
||||
OnPublishDropped(cl *Client, pk packets.Packet)
|
||||
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
|
||||
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
|
||||
OnQosComplete(cl *Client, pk packets.Packet)
|
||||
OnQosDropped(cl *Client, pk packets.Packet)
|
||||
OnWill(cl *Client, will Will) (Will, error)
|
||||
OnWillSent(cl *Client, pk packets.Packet)
|
||||
OnClientExpired(cl *Client)
|
||||
OnRetainedExpired(filter string)
|
||||
StoredClients() ([]storage.Client, error)
|
||||
StoredSubscriptions() ([]storage.Subscription, error)
|
||||
StoredInflightMessages() ([]storage.Message, error)
|
||||
StoredRetainedMessages() ([]storage.Message, error)
|
||||
StoredSysInfo() (storage.SystemInfo, error)
|
||||
}
|
||||
|
||||
// HookOptions contains values which are inherited from the server on initialisation.
|
||||
type HookOptions struct {
|
||||
Capabilities *Capabilities
|
||||
}
|
||||
|
||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||
type Hooks struct {
|
||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||
internal atomic.Value // a slice of []Hook
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex for locking when adding hooks
|
||||
}
|
||||
|
||||
// Len returns the number of hooks added.
|
||||
func (h *Hooks) Len() int64 {
|
||||
return atomic.LoadInt64(&h.qty)
|
||||
}
|
||||
|
||||
// Provides returns true if any one hook provides any of the requested hook methods.
|
||||
func (h *Hooks) Provides(b ...byte) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
for _, hb := range b {
|
||||
if hook.Provides(hb) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Add adds and initializes a new hook.
|
||||
func (h *Hooks) Add(hook Hook, config any) error {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
|
||||
err := hook.Init(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
|
||||
}
|
||||
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
i = []Hook{}
|
||||
}
|
||||
|
||||
i = append(i, hook)
|
||||
h.internal.Store(i)
|
||||
atomic.AddInt64(&h.qty, 1)
|
||||
h.wg.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll returns a slice of all the hooks.
|
||||
func (h *Hooks) GetAll() []Hook {
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
return []Hook{}
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// Stop indicates all attached hooks to gracefully end.
|
||||
func (h *Hooks) Stop() {
|
||||
go func() {
|
||||
for _, hook := range h.GetAll() {
|
||||
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
|
||||
if err := hook.Stop(); err != nil {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
|
||||
}
|
||||
|
||||
h.wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
h.wg.Wait()
|
||||
}
|
||||
|
||||
// OnSysInfoTick is called when the $SYS topic values are published out.
|
||||
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSysInfoTick) {
|
||||
hook.OnSysInfoTick(sys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnStarted is called when the server has successfully started.
|
||||
func (h *Hooks) OnStarted() {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStarted) {
|
||||
hook.OnStarted()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnStopped is called when the server has successfully stopped.
|
||||
func (h *Hooks) OnStopped() {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStopped) {
|
||||
hook.OnStopped()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnect) {
|
||||
hook.OnConnect(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSessionEstablished) {
|
||||
hook.OnSessionEstablished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnDisconnect) {
|
||||
hook.OnDisconnect(cl, err, expire)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a packet is received from a client.
|
||||
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketRead) {
|
||||
npk, err := hook.OnPacketRead(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected")
|
||||
return pk, err
|
||||
} else if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnAuthPacket is called when an auth packet is received. It is intended to allow developers
|
||||
// to create their own auth packet handling mechanisms.
|
||||
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnAuthPacket) {
|
||||
npk, err := hook.OnAuthPacket(cl, pkx)
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
|
||||
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketEncode) {
|
||||
pk = hook.OnPacketEncode(cl, pk)
|
||||
}
|
||||
}
|
||||
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
|
||||
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketProcessed) {
|
||||
hook.OnPacketProcessed(cl, pk, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
|
||||
// containing the bytes sent.
|
||||
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketSent) {
|
||||
hook.OnPacketSent(cl, pk, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribe is called when a client subscribes to one or more filters. This method
|
||||
// differs from OnSubscribed in that it allows you to modify the subscription values
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribe) {
|
||||
pk = hook.OnSubscribe(cl, pk)
|
||||
}
|
||||
}
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribed) {
|
||||
hook.OnSubscribed(cl, pk, reasonCodes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSelectSubscribers is called when subscribers have been collected for a topic, but before
|
||||
// shared subscription subscribers have been selected. This hook can be used to programmatically
|
||||
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
|
||||
// group in a custom manner (such as based on client id, ip, etc).
|
||||
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSelectSubscribers) {
|
||||
subs = hook.OnSelectSubscribers(subs, pk)
|
||||
}
|
||||
}
|
||||
return subs
|
||||
}
|
||||
|
||||
// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method
|
||||
// differs from OnUnsubscribed in that it allows you to modify the unsubscription values
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribe) {
|
||||
pk = hook.OnUnsubscribe(cl, pk)
|
||||
}
|
||||
}
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribed) {
|
||||
hook.OnUnsubscribed(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnPublished
|
||||
// in that it allows you to modify you to modify the incoming packet before it is processed.
|
||||
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
||||
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublish) {
|
||||
npk, err := hook.OnPublish(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
|
||||
return pk, err
|
||||
} else if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublished) {
|
||||
hook.OnPublished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublishDropped is called when a message to a client was dropped instead of delivered
|
||||
// such as when a client is too slow to respond.
|
||||
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublishDropped) {
|
||||
hook.OnPublishDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainMessage) {
|
||||
hook.OnRetainMessage(cl, pk, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
|
||||
// In other words, this method is called when a new inflight message is created or resent.
|
||||
// It is typically used to store a new inflight message.
|
||||
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosPublish) {
|
||||
hook.OnQosPublish(cl, pk, sent, resends)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
// In other words, when an inflight message is resolved.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosComplete) {
|
||||
hook.OnQosComplete(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires. In other words, when
|
||||
// an inflight message expires or is abandoned. It is typically used to delete an
|
||||
// inflight message from a store.
|
||||
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosDropped) {
|
||||
hook.OnQosDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message. This method
|
||||
// differs from OnWillSent in that it allows you to modify the LWT message before it is
|
||||
// published. The return values of the hook methods are passed-through in the order
|
||||
// the hooks were attached.
|
||||
func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWill) {
|
||||
mlwt, err := hook.OnWill(cl, will)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error")
|
||||
continue
|
||||
}
|
||||
will = mlwt
|
||||
}
|
||||
}
|
||||
|
||||
return will
|
||||
}
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWillSent) {
|
||||
hook.OnWillSent(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired is called when a client session has expired and should be deleted.
|
||||
func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnClientExpired) {
|
||||
hook.OnClientExpired(cl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired is called when a retained message has expired and should be deleted.
|
||||
func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainedExpired) {
|
||||
hook.OnRetainedExpired(filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all clients, e.g. from a persistent store, is used to
|
||||
// populate the server clients list before start.
|
||||
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredClients) {
|
||||
v, err := hook.StoredClients()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
|
||||
// used to populate the server subscriptions list before start.
|
||||
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSubscriptions) {
|
||||
v, err := hook.StoredSubscriptions()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
|
||||
// and is used to populate the restored clients with inflight messages before start.
|
||||
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredInflightMessages) {
|
||||
v, err := hook.StoredInflightMessages()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
|
||||
// and is used to populate the server topics with retained messages before start.
|
||||
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredRetainedMessages) {
|
||||
v, err := hook.StoredRetainedMessages()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSysInfo) {
|
||||
v, err := hook.StoredSysInfo()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if v.Version != "" {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
|
||||
// An implementation of this method MUST be used to allow or deny access to the
|
||||
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check connecting users against an existing user database.
|
||||
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnectAuthenticate) {
|
||||
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter.
|
||||
// An implementation of this method MUST be used to allow or deny access to the
|
||||
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check publishing and subscribing users against an existing permissions or roles database.
|
||||
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnACLCheck) {
|
||||
if ok := hook.OnACLCheck(cl, topic, write); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HookBase provides a set of default methods for each hook. It should be embedded in
|
||||
// all hooks.
|
||||
type HookBase struct {
|
||||
Hook
|
||||
Log *zerolog.Logger
|
||||
Opts *HookOptions
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *HookBase) ID() string {
|
||||
return "base"
|
||||
}
|
||||
|
||||
// Provides indicates which methods a hook provides. The default is none - this method
|
||||
// should be overridden by the embedding hook.
|
||||
func (h *HookBase) Provides(b byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Init performs any pre-start initializations for the hook, such as connecting to databases
|
||||
// or opening files.
|
||||
func (h *HookBase) Init(config any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetOpts is called by the server to propagate internal values and generally should
|
||||
// not be called manually.
|
||||
func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
|
||||
h.Log = l
|
||||
h.Opts = opts
|
||||
}
|
||||
|
||||
// Stop is called to gracefully shutdown the hook.
|
||||
func (h *HookBase) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStarted is called when the server starts.
|
||||
func (h *HookBase) OnStarted() {}
|
||||
|
||||
// OnStopped is called when the server stops.
|
||||
func (h *HookBase) OnStopped() {}
|
||||
|
||||
// OnSysInfoTick is called when the server publishes system info.
|
||||
func (h *HookBase) OnSysInfoTick(*system.Info) {}
|
||||
|
||||
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
|
||||
func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck is called when a user attempts to subscribe or publish to a topic.
|
||||
func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {}
|
||||
|
||||
// OnAuthPacket is called when an auth packet is received from the client.
|
||||
func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a packet is received.
|
||||
func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketEncode is called before a packet is byte-encoded and written to the client.
|
||||
func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnPacketSent is called immediately after a packet is written to a client.
|
||||
func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {}
|
||||
|
||||
// OnPacketProcessed is called immediately after a packet from a client is processed.
|
||||
func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {}
|
||||
|
||||
// OnSubscribe is called when a client subscribes to one or more filters.
|
||||
func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {}
|
||||
|
||||
// OnSelectSubscribers is called when selecting subscribers to receive a message.
|
||||
func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
return subs
|
||||
}
|
||||
|
||||
// OnUnsubscribe is called when a client unsubscribes from one or more filters.
|
||||
func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPublish is called when a client publishes a message.
|
||||
func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
|
||||
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
|
||||
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message.
|
||||
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
|
||||
return will, nil
|
||||
}
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnClientExpired is called when a client session has expired.
|
||||
func (h *HookBase) OnClientExpired(cl *Client) {}
|
||||
|
||||
// OnRetainedExpired is called when a retained message for a topic has expired.
|
||||
func (h *HookBase) OnRetainedExpired(topic string) {}
|
||||
|
||||
// StoredClients returns all clients from a store.
|
||||
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all subcriptions from a store.
|
||||
func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all inflight messages from a store.
|
||||
func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all retained messages from a store.
|
||||
func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
return
|
||||
}
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
// AllowHook is an authentication hook which allows connection access
|
||||
// for all users and read and write access to all topics.
|
||||
type AllowHook struct {
|
||||
mqtt.HookBase
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *AllowHook) ID() string {
|
||||
return "allow-all-auth"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *AllowHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate returns true/allowed for all requests.
|
||||
func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// OnACLCheck returns true/allowed for all checks.
|
||||
func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAllowAllID(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.Equal(t, "allow-all-auth", h.ID())
|
||||
}
|
||||
|
||||
func TestAllowAllProvides(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
require.False(t, h.Provides(mqtt.OnPublished))
|
||||
}
|
||||
|
||||
func TestAllowAllOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{}))
|
||||
}
|
||||
|
||||
func TestAllowAllOnACLCheck(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true))
|
||||
}
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
// Options contains the configuration/rules data for the auth ledger.
|
||||
type Options struct {
|
||||
Data []byte
|
||||
Ledger *Ledger
|
||||
}
|
||||
|
||||
// Hook is an authentication hook which implements an auth ledger.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options
|
||||
ledger *Ledger
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "auth-ledger"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init configures the hook with the auth ledger to be used for checking.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
|
||||
var err error
|
||||
if h.config.Ledger != nil {
|
||||
h.ledger = h.config.Ledger
|
||||
} else if len(h.config.Data) > 0 {
|
||||
h.ledger = new(Ledger)
|
||||
err = h.ledger.Unmarshal(h.config.Data)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.ledger == nil {
|
||||
h.ledger = &Ledger{
|
||||
Auth: AuthRules{},
|
||||
ACL: ACLRules{},
|
||||
}
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Int("authentication", len(h.ledger.Auth)).
|
||||
Int("acl", len(h.ledger.ACL)).
|
||||
Msg("loaded auth rules")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate returns true if the connecting client has rules which provide access
|
||||
// in the auth ledger.
|
||||
func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
if _, ok := h.ledger.AuthOk(cl, pk); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Str("username", string(pk.Connect.Username)).
|
||||
Str("remote", cl.Net.Remote).
|
||||
Msg("client failed authentication check")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck returns true if the connecting client has matching read or write access to subscribe
|
||||
// or publish to a given topic.
|
||||
func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
if _, ok := h.ledger.ACLOk(cl, topic, write); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Debug().
|
||||
Str("client", cl.ID).
|
||||
Str("username", string(cl.Properties.Username)).
|
||||
Str("topic", topic).
|
||||
Msg("client failed allowed ACL check")
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
// func teardown(t *testing.T, path string, h *Hook) {
|
||||
// h.Stop()
|
||||
// }
|
||||
|
||||
func TestBasicID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "auth-ledger", h.ID())
|
||||
}
|
||||
|
||||
func TestBasicProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
require.False(t, h.Provides(mqtt.OnPublish))
|
||||
}
|
||||
|
||||
func TestBasicInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBasicInitDefaultConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerPointer(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
ln := &Ledger{
|
||||
Auth: []AuthRule{
|
||||
{
|
||||
Remote: "127.0.0.1",
|
||||
Allow: true,
|
||||
},
|
||||
},
|
||||
ACL: []ACLRule{
|
||||
{
|
||||
Remote: "127.0.0.1",
|
||||
Filters: Filters{
|
||||
"#": ReadWrite,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := h.Init(&Options{
|
||||
Ledger: ln,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Same(t, ln, h.ledger)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerJSON(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: ledgerJSON,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
|
||||
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerYAML(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: ledgerYAML,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
|
||||
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: []byte("fdsfdsafasd"),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
ln.ACL = checkLedger.ACL
|
||||
err := h.Init(
|
||||
&Options{
|
||||
Ledger: ln,
|
||||
},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
))
|
||||
|
||||
require.False(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
))
|
||||
|
||||
require.False(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{},
|
||||
packets.Packet{},
|
||||
))
|
||||
}
|
||||
|
||||
func TestOnACL(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
ln.ACL = checkLedger.ACL
|
||||
err := h.Init(
|
||||
&Options{
|
||||
Ledger: ln,
|
||||
},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"mochi/info",
|
||||
true,
|
||||
))
|
||||
|
||||
require.False(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"d/j/f",
|
||||
true,
|
||||
))
|
||||
|
||||
require.True(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"readonly",
|
||||
false,
|
||||
))
|
||||
|
||||
require.False(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"readonly",
|
||||
true,
|
||||
))
|
||||
}
|
||||
|
|
@ -1,231 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
Deny Access = iota // user cannot access the topic
|
||||
ReadOnly // user can only subscribe to the topic
|
||||
WriteOnly // user can only publish to the topic
|
||||
ReadWrite // user can both publish and subscribe to the topic
|
||||
)
|
||||
|
||||
// Access determines the read/write privileges for an ACL rule.
|
||||
type Access byte
|
||||
|
||||
// Users contains a map of access rules for specific users, keyed on username.
|
||||
type Users map[string]UserRule
|
||||
|
||||
// UserRule defines a set of access rules for a specific user.
|
||||
type UserRule struct {
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
|
||||
ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired
|
||||
Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user
|
||||
}
|
||||
|
||||
// AuthRules defines generic access rules applicable to all users.
|
||||
type AuthRules []AuthRule
|
||||
|
||||
type AuthRule struct {
|
||||
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
|
||||
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
|
||||
Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users
|
||||
}
|
||||
|
||||
// ACLRules defines generic topic or filter access rules applicable to all users.
|
||||
type ACLRules []ACLRule
|
||||
|
||||
// ACLRule defines access rules for a specific topic or filter.
|
||||
type ACLRule struct {
|
||||
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
|
||||
Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match
|
||||
}
|
||||
|
||||
// Filters is a map of Access rules keyed on filter.
|
||||
type Filters map[RString]Access
|
||||
|
||||
// RString is a rule value string.
|
||||
type RString string
|
||||
|
||||
// Matches returns true if the rule matches a given string.
|
||||
func (r RString) Matches(a string) bool {
|
||||
rr := string(r)
|
||||
if r == "" || r == "*" || a == rr {
|
||||
return true
|
||||
}
|
||||
|
||||
i := strings.Index(rr, "*")
|
||||
if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// FilterMatches returns true if a filter matches a topic rule.
|
||||
func (f RString) FilterMatches(a string) bool {
|
||||
_, ok := MatchTopic(string(f), a)
|
||||
return ok
|
||||
}
|
||||
|
||||
// MatchTopic checks if a given topic matches a filter, accounting for filter
|
||||
// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c.
|
||||
func MatchTopic(filter string, topic string) (elements []string, matched bool) {
|
||||
filterParts := strings.Split(filter, "/")
|
||||
topicParts := strings.Split(topic, "/")
|
||||
|
||||
elements = make([]string, 0)
|
||||
for i := 0; i < len(filterParts); i++ {
|
||||
if i >= len(topicParts) {
|
||||
matched = false
|
||||
return
|
||||
}
|
||||
|
||||
if filterParts[i] == "+" {
|
||||
elements = append(elements, topicParts[i])
|
||||
continue
|
||||
}
|
||||
|
||||
if filterParts[i] == "#" {
|
||||
matched = true
|
||||
elements = append(elements, strings.Join(topicParts[i:], "/"))
|
||||
return
|
||||
}
|
||||
|
||||
if filterParts[i] != topicParts[i] {
|
||||
matched = false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return elements, true
|
||||
}
|
||||
|
||||
// Ledger is an auth ledger containing access rules for users and topics.
|
||||
type Ledger struct {
|
||||
sync.Mutex `json:"-" yaml:"-"`
|
||||
Users Users `json:"users" yaml:"users"`
|
||||
Auth AuthRules `json:"auth" yaml:"auth"`
|
||||
ACL ACLRules `json:"acl" yaml:"acl"`
|
||||
}
|
||||
|
||||
// Update updates the internal values of the ledger.
|
||||
func (l *Ledger) Update(ln *Ledger) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Auth = ln.Auth
|
||||
l.ACL = ln.ACL
|
||||
}
|
||||
|
||||
// AuthOk returns true if the rules indicate the user is allowed to authenticate.
|
||||
func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
|
||||
// If the users map is set, always check for a predefined user first instead
|
||||
// of iterating through global rules.
|
||||
if l.Users != nil {
|
||||
if u, ok := l.Users[string(cl.Properties.Username)]; ok &&
|
||||
u.Password != "" &&
|
||||
u.Password == RString(pk.Connect.Password) {
|
||||
return 0, !u.Disallow
|
||||
}
|
||||
}
|
||||
|
||||
// If there's no users map, or no user was found, attempt to find a matching
|
||||
// rule (which may also contain a user).
|
||||
for n, rule := range l.Auth {
|
||||
if rule.Client.Matches(cl.ID) &&
|
||||
rule.Username.Matches(string(cl.Properties.Username)) &&
|
||||
rule.Password.Matches(string(pk.Connect.Password)) &&
|
||||
rule.Remote.Matches(cl.Net.Remote) {
|
||||
return n, rule.Allow
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// ACLOk returns true if the rules indicate the user is allowed to read or write to
|
||||
// a specific filter or topic respectively, based on the write bool.
|
||||
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
|
||||
// If the users map is set, always check for a predefined user first instead
|
||||
// of iterating through global rules.
|
||||
if l.Users != nil {
|
||||
if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 {
|
||||
for filter, access := range u.ACL {
|
||||
if filter.FilterMatches(topic) {
|
||||
if !write && (access == ReadOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else if write && (access == WriteOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else {
|
||||
return n, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for n, rule := range l.ACL {
|
||||
if rule.Client.Matches(cl.ID) &&
|
||||
rule.Username.Matches(string(cl.Properties.Username)) &&
|
||||
rule.Remote.Matches(cl.Net.Remote) {
|
||||
if len(rule.Filters) == 0 {
|
||||
return n, true
|
||||
}
|
||||
|
||||
for filter, access := range rule.Filters {
|
||||
if filter.FilterMatches(topic) {
|
||||
if !write && (access == ReadOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else if write && (access == WriteOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else {
|
||||
return n, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, true
|
||||
}
|
||||
|
||||
// ToJSON encodes the values into a JSON string.
|
||||
func (l *Ledger) ToJSON() (data []byte, err error) {
|
||||
return json.Marshal(l)
|
||||
}
|
||||
|
||||
// ToYAML encodes the values into a YAML string.
|
||||
func (l *Ledger) ToYAML() (data []byte, err error) {
|
||||
return yaml.Marshal(l)
|
||||
}
|
||||
|
||||
// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct.
|
||||
func (l *Ledger) Unmarshal(data []byte) error {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if data[0] == '{' {
|
||||
return json.Unmarshal(data, l)
|
||||
}
|
||||
|
||||
return yaml.Unmarshal(data, &l)
|
||||
}
|
||||
|
|
@ -1,610 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
checkLedger = Ledger{
|
||||
Users: Users{ // users are allowed by default
|
||||
"mochi-co": {
|
||||
Password: "melon",
|
||||
ACL: Filters{
|
||||
"d/+/f": Deny,
|
||||
"mochi-co/#": ReadWrite,
|
||||
"readonly": ReadOnly,
|
||||
},
|
||||
},
|
||||
"suspended-username": {
|
||||
Password: "any",
|
||||
Disallow: true,
|
||||
},
|
||||
"mochi": { // ACL only, will defer to AuthRules for authentication
|
||||
ACL: Filters{
|
||||
"special/mochi": ReadOnly,
|
||||
"secret/mochi": Deny,
|
||||
"ignored": ReadWrite,
|
||||
},
|
||||
},
|
||||
},
|
||||
Auth: AuthRules{
|
||||
{Username: "banned-user"}, // never allow specific username
|
||||
{Remote: "127.0.0.1", Allow: true}, // always allow localhost
|
||||
{Remote: "123.123.123.123"}, // disallow any from specific address
|
||||
{Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address
|
||||
{Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username)
|
||||
{Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass
|
||||
{Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map)
|
||||
},
|
||||
ACL: ACLRules{
|
||||
{
|
||||
Username: "mochi", // allow matching user/pass
|
||||
Filters: Filters{
|
||||
"a/b/c": Deny,
|
||||
"d/+/f": Deny,
|
||||
"mochi/#": ReadWrite,
|
||||
"updates/#": WriteOnly,
|
||||
"readonly": ReadOnly,
|
||||
"ignored": Deny,
|
||||
},
|
||||
},
|
||||
{Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost
|
||||
{Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin
|
||||
{Remote: "001.002.003.004"}, // Allow all with no filter
|
||||
{Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func TestRStringMatches(t *testing.T) {
|
||||
require.True(t, RString("*").Matches("any"))
|
||||
require.True(t, RString("*").Matches(""))
|
||||
require.True(t, RString("").Matches("any"))
|
||||
require.True(t, RString("").Matches(""))
|
||||
require.False(t, RString("no").Matches("any"))
|
||||
require.False(t, RString("no").Matches(""))
|
||||
}
|
||||
|
||||
func TestCanAuthenticate(t *testing.T) {
|
||||
tt := []struct {
|
||||
desc string
|
||||
client *mqtt.Client
|
||||
pk packets.Packet
|
||||
n int
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
desc: "allow all local 127.0.0.1",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{}},
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 5,
|
||||
},
|
||||
{
|
||||
desc: "deny username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "allow all local 127.0.0.1",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 5,
|
||||
},
|
||||
{
|
||||
desc: "deny username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "deny client from address",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("not-mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "111.144.155.166",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: false,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
desc: "allow remote wildcard",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "111.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: true,
|
||||
n: 4,
|
||||
},
|
||||
{
|
||||
desc: "never allow username",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("banned-user"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "matching user in users",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi-co"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "never user in users",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("suspended-user"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range tt {
|
||||
t.Run(d.desc, func(t *testing.T) {
|
||||
n, ok := checkLedger.AuthOk(d.client, d.pk)
|
||||
require.Equal(t, d.n, n)
|
||||
require.Equal(t, d.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanACL(t *testing.T) {
|
||||
tt := []struct {
|
||||
client *mqtt.Client
|
||||
desc string
|
||||
topic string
|
||||
n int
|
||||
write bool
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
desc: "allow normal write on any other filter",
|
||||
client: &mqtt.Client{},
|
||||
topic: "default/acl/write/access",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow normal read on any other filter",
|
||||
client: &mqtt.Client{},
|
||||
topic: "default/acl/read/access",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny user on literal filter",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "a/b/c",
|
||||
},
|
||||
{
|
||||
desc: "deny user on partial filter",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "d/j/f",
|
||||
},
|
||||
{
|
||||
desc: "allow read/write to user path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "mochi/read/write",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny read on write-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/no/reading",
|
||||
write: false,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "deny read on write-only path ext",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/mochi",
|
||||
write: false,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "allow read on not-acl path (no #)",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow write on write-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/mochi",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny write on read-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "readonly",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "allow read on read-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "readonly",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow $sys access to localhost",
|
||||
client: &mqtt.Client{
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "localhost",
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow $sys access to admin",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("admin"),
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: true,
|
||||
n: 2,
|
||||
},
|
||||
{
|
||||
desc: "deny $sys access to all others",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: false,
|
||||
n: 4,
|
||||
},
|
||||
{
|
||||
desc: "allow all with no filter",
|
||||
client: &mqtt.Client{
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "001.002.003.004",
|
||||
},
|
||||
},
|
||||
topic: "any/path",
|
||||
write: true,
|
||||
ok: true,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl deny",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "secret/mochi",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl any",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "any/mochi",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl write on read-only",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "special/mochi",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl read on read-only",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "special/mochi",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "preference users embedded acl",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "ignored",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range tt {
|
||||
t.Run(d.desc, func(t *testing.T) {
|
||||
n, ok := checkLedger.ACLOk(d.client, d.topic, d.write)
|
||||
require.Equal(t, d.n, n)
|
||||
require.Equal(t, d.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchTopic(t *testing.T) {
|
||||
el, matched := MatchTopic("a/+/c/+", "a/b/c/d")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "d"}, el)
|
||||
|
||||
el, matched = MatchTopic("a/+/+/+", "a/b/c/d")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "c", "d"}, el)
|
||||
|
||||
el, matched = MatchTopic("stuff/#", "stuff/things/yeah")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"things/yeah"}, el)
|
||||
|
||||
el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "c/d/as/dds"}, el)
|
||||
|
||||
el, matched = MatchTopic("test", "test")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic("things/stuff//", "things/stuff/")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic("t", "t2")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic(" ", " ")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
}
|
||||
|
||||
var (
|
||||
ledgerStruct = Ledger{
|
||||
Users: Users{
|
||||
"mochi": {
|
||||
Password: "peach",
|
||||
ACL: Filters{
|
||||
"readonly": ReadOnly,
|
||||
"deny": Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
Auth: AuthRules{
|
||||
{
|
||||
Client: "*",
|
||||
Username: "mochi-co",
|
||||
Password: "melon",
|
||||
Remote: "192.168.1.*",
|
||||
Allow: true,
|
||||
},
|
||||
},
|
||||
ACL: ACLRules{
|
||||
{
|
||||
Client: "*",
|
||||
Username: "mochi-co",
|
||||
Remote: "127.*",
|
||||
Filters: Filters{
|
||||
"readonly": ReadOnly,
|
||||
"writeonly": WriteOnly,
|
||||
"readwrite": ReadWrite,
|
||||
"deny": Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`)
|
||||
ledgerYAML = []byte(`users:
|
||||
mochi:
|
||||
password: peach
|
||||
acl:
|
||||
deny: 0
|
||||
readonly: 1
|
||||
auth:
|
||||
- client: '*'
|
||||
username: mochi-co
|
||||
remote: 192.168.1.*
|
||||
password: melon
|
||||
allow: true
|
||||
acl:
|
||||
- client: '*'
|
||||
username: mochi-co
|
||||
remote: 127.*
|
||||
filters:
|
||||
deny: 0
|
||||
readonly: 1
|
||||
readwrite: 3
|
||||
writeonly: 2
|
||||
`)
|
||||
)
|
||||
|
||||
func TestLedgerUpdate(t *testing.T) {
|
||||
old := &Ledger{
|
||||
Auth: AuthRules{
|
||||
{Remote: "127.0.0.1", Allow: true},
|
||||
},
|
||||
}
|
||||
|
||||
new := &Ledger{
|
||||
Auth: AuthRules{
|
||||
{Remote: "127.0.0.1", Allow: true},
|
||||
{Remote: "192.168.*", Allow: true},
|
||||
},
|
||||
}
|
||||
|
||||
old.Update(new)
|
||||
require.Len(t, old.Auth, 2)
|
||||
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
|
||||
require.NotSame(t, new, old)
|
||||
}
|
||||
|
||||
func TestLedgerToJSON(t *testing.T) {
|
||||
data, err := ledgerStruct.ToJSON()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerJSON, data)
|
||||
}
|
||||
|
||||
func TestLedgerToYAML(t *testing.T) {
|
||||
data, err := ledgerStruct.ToYAML()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerYAML, data)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalFromYAML(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal(ledgerYAML)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &ledgerStruct, l)
|
||||
require.NotSame(t, l, &ledgerStruct)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalFromJSON(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal(ledgerJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &ledgerStruct, l)
|
||||
require.NotSame(t, l, &ledgerStruct)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalNil(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, new(Ledger), l)
|
||||
}
|
||||
|
|
@ -1,250 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Options contains configuration settings for the debug output.
|
||||
type Options struct {
|
||||
ShowPacketData bool // include decoded packet data (default false)
|
||||
ShowPings bool // show ping requests and responses (default false)
|
||||
ShowPasswords bool // show connecting user passwords (default false)
|
||||
}
|
||||
|
||||
// Hook is a debugging hook which logs additional low-level information from the server.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options
|
||||
Log *zerolog.Logger
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "debug"
|
||||
}
|
||||
|
||||
// Provides indicates that this hook provides all methods.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Init is called when the hook is initialized.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetOpts is called when the hook receives inheritable server parameters.
|
||||
func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) {
|
||||
h.Log = l
|
||||
h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send()
|
||||
}
|
||||
|
||||
// Stop is called when the hook is stopped.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Debug().Str("method", "Stop").Send()
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStarted is called when the server starts.
|
||||
func (h *Hook) OnStarted() {
|
||||
h.Log.Debug().Str("method", "OnStarted").Send()
|
||||
}
|
||||
|
||||
// OnStopped is called when the server stops.
|
||||
func (h *Hook) OnStopped() {
|
||||
h.Log.Debug().Str("method", "OnStopped").Send()
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a new packet is received from a client.
|
||||
func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet is sent to a client.
|
||||
func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
|
||||
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
|
||||
return
|
||||
}
|
||||
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
|
||||
}
|
||||
|
||||
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic")
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out")
|
||||
}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete")
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped")
|
||||
}
|
||||
|
||||
// OnLWTSent is called when a will message has been issued from a disconnecting client.
|
||||
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client")
|
||||
}
|
||||
|
||||
// OnRetainedExpired is called when the server clears expired retained messages.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired")
|
||||
}
|
||||
|
||||
// OnClientExpired is called when the server clears an expired client.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired")
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores clients from a store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredClients").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores subscriptions from a store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredSubscriptions").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores retained messages from a store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredRetainedMessages").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores inflight messages from a store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredInflightMessages").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores system info from a store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredClients").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// packetMeta adds additional type-specific metadata to the debug logs.
|
||||
func (h *Hook) packetMeta(pk packets.Packet) map[string]any {
|
||||
m := map[string]any{}
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
m["id"] = pk.Connect.ClientIdentifier
|
||||
m["clean"] = pk.Connect.Clean
|
||||
m["keepalive"] = pk.Connect.Keepalive
|
||||
m["version"] = pk.ProtocolVersion
|
||||
m["username"] = string(pk.Connect.Username)
|
||||
if h.config.ShowPasswords {
|
||||
m["password"] = string(pk.Connect.Password)
|
||||
}
|
||||
if pk.Connect.WillFlag {
|
||||
m["will_topic"] = pk.Connect.WillTopic
|
||||
m["will_payload"] = string(pk.Connect.WillPayload)
|
||||
}
|
||||
case packets.Publish:
|
||||
m["topic"] = pk.TopicName
|
||||
m["payload"] = string(pk.Payload)
|
||||
m["raw"] = pk.Payload
|
||||
m["qos"] = pk.FixedHeader.Qos
|
||||
m["id"] = pk.PacketID
|
||||
case packets.Connack:
|
||||
fallthrough
|
||||
case packets.Disconnect:
|
||||
fallthrough
|
||||
case packets.Puback:
|
||||
fallthrough
|
||||
case packets.Pubrec:
|
||||
fallthrough
|
||||
case packets.Pubrel:
|
||||
fallthrough
|
||||
case packets.Pubcomp:
|
||||
m["id"] = pk.PacketID
|
||||
m["reason"] = int(pk.ReasonCode)
|
||||
if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 {
|
||||
m["reason_string"] = pk.Properties.ReasonString
|
||||
}
|
||||
case packets.Subscribe:
|
||||
f := map[string]int{}
|
||||
ids := map[string]int{}
|
||||
for _, v := range pk.Filters {
|
||||
f[v.Filter] = int(v.Qos)
|
||||
ids[v.Filter] = v.Identifier
|
||||
}
|
||||
m["filters"] = f
|
||||
m["subids"] = f
|
||||
|
||||
case packets.Unsubscribe:
|
||||
f := []string{}
|
||||
for _, v := range pk.Filters {
|
||||
f = append(f, v.Filter)
|
||||
}
|
||||
m["filters"] = f
|
||||
case packets.Suback:
|
||||
fallthrough
|
||||
case packets.Unsuback:
|
||||
r := []int{}
|
||||
for _, v := range pk.ReasonCodes {
|
||||
r = append(r, int(v))
|
||||
}
|
||||
m["reasons"] = r
|
||||
case packets.Auth:
|
||||
// tbd
|
||||
}
|
||||
|
||||
if h.config.ShowPacketData {
|
||||
m["packet"] = pk
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
|
@ -1,473 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/timshannon/badgerhold"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultDbFile is the default file path for the badger db file.
|
||||
defaultDbFile = ".badger"
|
||||
)
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return storage.RetainedKey + "_" + topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the BadgerDB instance.
|
||||
type Options struct {
|
||||
Options *badgerhold.Options
|
||||
Path string
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using BadgerDB file store as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for configuring the BadgerDB instance.
|
||||
db *badgerhold.Store // the BadgerDB instance.
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "badger-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init initializes and connects to the badger instance.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
if h.config.Path == "" {
|
||||
h.config.Path = defaultDbFile
|
||||
}
|
||||
|
||||
options := badgerhold.DefaultOptions
|
||||
options.Dir = h.config.Path
|
||||
options.ValueDir = h.config.Path
|
||||
options.Logger = h
|
||||
|
||||
var err error
|
||||
h.db, err = badgerhold.Open(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the badger instance.
|
||||
func (h *Hook) Stop() error {
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: clientKey(cl),
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if their session has expired.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
h.updateClient(cl)
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
PacketID: pk.PacketID,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(inflightKey(cl, pk), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(retainedKey(filter), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.ClientKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.SubscriptionKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.RetainedKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Get(storage.SysInfoKey, &v)
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Errorf satisfies the badger interface for an error logger.
|
||||
func (h *Hook) Errorf(m string, v ...interface{}) {
|
||||
h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
}
|
||||
|
||||
// Warningf satisfies the badger interface for a warning logger.
|
||||
func (h *Hook) Warningf(m string, v ...interface{}) {
|
||||
h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
}
|
||||
|
||||
// Infof satisfies the badger interface for an info logger.
|
||||
func (h *Hook) Infof(m string, v ...interface{}) {
|
||||
h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
}
|
||||
|
||||
// Debugf satisfies the badger interface for a debug logger.
|
||||
func (h *Hook) Debugf(m string, v ...interface{}) {
|
||||
h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
}
|
||||
|
|
@ -1,681 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/timshannon/badgerhold"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
h.Stop()
|
||||
h.db.Badger().Close()
|
||||
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, "cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, storage.InflightKey+"_cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "badger-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
require.Equal(t, defaultDbFile, h.config.Path)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.Get(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
err = h.db.Get(clientKey(client), r2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
err = h.db.Get(clientKey(client), r3)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err = h.db.Upsert(clientKey, &storage.Client{ID: cl.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.Get(clientKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ID, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
err = h.db.Get(clientKey, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.Get(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
r := new(storage.Subscription)
|
||||
|
||||
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, badgerhold.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.Get(retainedKey(pk.TopicName), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.db.Get(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.db.Get(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err = h.db.Upsert(m.ID, m)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.Get(m.ID, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
err = h.db.Get(m.ID, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.Get(inflightKey(client, pk), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
err = h.db.Get(inflightKey(client, pk), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
err = h.db.Get(storage.SysInfoKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with clients
|
||||
err = h.db.Upsert("cl1", &storage.Client{ID: "cl1", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("cl2", &storage.Client{ID: "cl2", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("cl3", &storage.Client{ID: "cl3", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err = h.db.Upsert("sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("m2", &storage.Message{ID: "m2", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("m3", &storage.Message{ID: "m3", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Upsert(storage.SysInfoKey, &storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestErrorf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.Errorf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestWarningf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.Warningf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestInfof(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.Infof("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.Debugf("test", 1, 2, 3)
|
||||
}
|
||||
|
|
@ -1,474 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
sgob "github.com/asdine/storm/codec/gob"
|
||||
"github.com/asdine/storm/v3"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultDbFile is the default file path for the boltdb file.
|
||||
defaultDbFile = "bolt.db"
|
||||
|
||||
// defaultTimeout is the default time to hold a connection to the file.
|
||||
defaultTimeout = 250 * time.Millisecond
|
||||
)
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return storage.RetainedKey + "_" + topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the bolt instance.
|
||||
type Options struct {
|
||||
Options *bbolt.Options
|
||||
Path string
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using boltdb file store as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for configuring the boltdb instance.
|
||||
db *storm.DB // the boltdb instance.
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "bolt-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init initializes and connects to the boltdb instance.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
if h.config.Options == nil {
|
||||
h.config.Options = &bbolt.Options{
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
if h.config.Path == "" {
|
||||
h.config.Path = defaultDbFile
|
||||
}
|
||||
|
||||
var err error
|
||||
h.db, err = storm.Open(h.config.Path, storm.BoltOptions(0600, h.config.Options), storm.Codec(sgob.Codec))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the boltdb instance.
|
||||
func (h *Hook) Stop() error {
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: clientKey(cl),
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.DeleteStruct(&storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", subscriptionKey(cl, pk.Filters[i].Filter)).
|
||||
Msg("failed to delete client")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.DeleteStruct(&storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", retainedKey(pk.TopicName)).
|
||||
Msg("failed to delete retained publish")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save retained publish data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save qos inflight data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", inflightKey(cl, pk)).
|
||||
Msg("failed to delete inflight data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
}
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Interface("data", in).
|
||||
Msg("failed to save $SYS data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find("T", storage.ClientKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find("T", storage.SubscriptionKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find("T", storage.RetainedKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find("T", storage.InflightKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.One("ID", storage.SysInfoKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
|
@ -1,717 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/asdine/storm/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
h.Stop()
|
||||
err := os.Remove(path)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, "cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, storage.InflightKey+"_cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "bolt-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
require.Equal(t, defaultTimeout, h.config.Options.Timeout)
|
||||
require.Equal(t, defaultDbFile, h.config.Path)
|
||||
}
|
||||
|
||||
func TestInitBadPath(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Path: "..",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey(client), r2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey(client), r3)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, storm.ErrNotFound, err)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err = h.db.Save(&storage.Client{ID: cl.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ID, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
err = h.db.One("ID", clientKey, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
r := new(storage.Subscription)
|
||||
|
||||
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.One("ID", retainedKey(pk.TopicName), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.db.One("ID", retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.db.One("ID", retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err = h.db.Save(m)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.One("ID", m.ID, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
err = h.db.One("ID", m.ID, r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.One("ID", inflightKey(client, pk), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved to bolt
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
err = h.db.One("ID", inflightKey(client, pk), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
err = h.db.One("ID", storage.SysInfoKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with clients
|
||||
err = h.db.Save(&storage.Client{ID: "cl1", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Client{ID: "cl2", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Client{ID: "cl3", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredClientsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err = h.db.Save(&storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "m2", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "m3", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with sys info
|
||||
err = h.db.Save(&storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
|
@ -1,513 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
// defaultAddr is the default address to the redis service.
|
||||
const defaultAddr = "localhost:6379"
|
||||
|
||||
// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt.
|
||||
const defaultHPrefix = "mochi-"
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the bolt instance.
|
||||
type Options struct {
|
||||
HPrefix string
|
||||
Options *redis.Options
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using Redis as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for connecting to the Redis instance.
|
||||
db *redis.Client // the Redis instance
|
||||
ctx context.Context // a context for the connection
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "redis-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// hKey returns a hash set key with a unique prefix.
|
||||
func (h *Hook) hKey(s string) string {
|
||||
return h.config.HPrefix + s
|
||||
}
|
||||
|
||||
// Init initializes and connects to the redis service.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
h.ctx = context.Background()
|
||||
|
||||
if config == nil {
|
||||
config = &Options{
|
||||
Options: &redis.Options{
|
||||
Addr: defaultAddr,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
if h.config.HPrefix == "" {
|
||||
h.config.HPrefix = defaultHPrefix
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Str("address", h.config.Options.Addr).
|
||||
Str("username", h.config.Options.Username).
|
||||
Int("password-len", len(h.config.Options.Password)).
|
||||
Int("db", h.config.Options.DB).
|
||||
Msg("connecting to redis service")
|
||||
|
||||
h.db = redis.NewClient(h.config.Options)
|
||||
_, err := h.db.Ping(context.Background()).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to ping service: %w", err)
|
||||
}
|
||||
|
||||
h.Log.Info().Msg("connected to redis service")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the redis connection.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Info().Msg("disconnecting from redis service")
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: clientKey(cl),
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll client data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Client
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data")
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll subscription data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Subscription
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data")
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll retained message data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data")
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = v.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data")
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
|
@ -1,789 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
miniredis "github.com/alicebob/miniredis/v2"
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func newHook(t *testing.T, addr string) *Hook {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
Addr: addr,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func teardown(t *testing.T, h *Hook) {
|
||||
if h.db != nil {
|
||||
err := h.db.FlushAll(h.ctx).Err()
|
||||
require.NoError(t, err)
|
||||
h.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, "cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, "cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, "a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, "cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
require.Equal(t, "redis-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestHKey(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.SetOpts(&logger, nil)
|
||||
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
s.StartAddr(defaultAddr)
|
||||
defer s.Close()
|
||||
|
||||
h := newHook(t, defaultAddr)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h)
|
||||
|
||||
require.Equal(t, defaultHPrefix, h.config.HPrefix)
|
||||
require.Equal(t, defaultAddr, h.config.Options.Addr)
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitBadAddr(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
Addr: "abc:123",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r2.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
|
||||
h.db = nil
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientKey, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, redis.Nil, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
|
||||
r := new(storage.Subscription)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved to bolt
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with clients
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredClientsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with messages
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with messages
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with sys info
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey,
|
||||
&storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
|
@ -1,164 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
)
|
||||
|
||||
const (
|
||||
SubscriptionKey = "SUB" // unique key to denote Subscriptions in a store
|
||||
SysInfoKey = "SYS" // unique key to denote server system information in a store
|
||||
RetainedKey = "RET" // unique key to denote retained messages in a store
|
||||
InflightKey = "IFM" // unique key to denote inflight messages in a store
|
||||
ClientKey = "CL" // unique key to denote clients in a store
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading.
|
||||
ErrDBFileNotOpen = errors.New("db file not open")
|
||||
)
|
||||
|
||||
// Client is a storable representation of an mqtt client.
|
||||
type Client struct {
|
||||
Will ClientWill `json:"will"` // will topic and payload data if applicable
|
||||
Properties ClientProperties `json:"properties"` // the connect properties for the client
|
||||
Username []byte `json:"username"` // the username of the client
|
||||
ID string `json:"id" storm:"id"` // the client id / storage key
|
||||
T string `json:"t"` // the data type (client)
|
||||
Remote string `json:"remote"` // the remote address of the client
|
||||
Listener string `json:"listener"` // the listener the client connected on
|
||||
ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client
|
||||
Clean bool `json:"clean"` // if the client requested a clean start/session
|
||||
}
|
||||
|
||||
// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection.
|
||||
type ClientProperties struct {
|
||||
AuthenticationData []byte `json:"authenticationData"`
|
||||
User []packets.UserProperty `json:"user"`
|
||||
AuthenticationMethod string `json:"authenticationMethod"`
|
||||
SessionExpiryInterval uint32 `json:"sessionExpiryInterval"`
|
||||
MaximumPacketSize uint32 `json:"maximumPacketSize"`
|
||||
ReceiveMaximum uint16 `json:"receiveMaximum"`
|
||||
TopicAliasMaximum uint16 `json:"topicAliasMaximum"`
|
||||
SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag"`
|
||||
RequestProblemInfo byte `json:"requestProblemInfo"`
|
||||
RequestProblemInfoFlag bool `json:"requestProblemInfoFlag"`
|
||||
RequestResponseInfo byte `json:"requestResponseInfo"`
|
||||
}
|
||||
|
||||
// ClientWill contains a will message for a client, and limited mqtt v5 properties.
|
||||
type ClientWill struct {
|
||||
Payload []byte `json:"payload"`
|
||||
User []packets.UserProperty `json:"user"`
|
||||
TopicName string `json:"topicName"`
|
||||
Flag uint32 `json:"flag"`
|
||||
WillDelayInterval uint32 `json:"willDelayInterval"`
|
||||
Qos byte `json:"qos"`
|
||||
Retain bool `json:"retain"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Client) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Client) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// Message is a storable representation of an MQTT message (specifically publish).
|
||||
type Message struct {
|
||||
Properties MessageProperties `json:"properties"` // -
|
||||
Payload []byte `json:"payload"` // the message payload (if retained)
|
||||
T string `json:"t"` // the data type
|
||||
ID string `json:"id" storm:"id"` // the storage key
|
||||
Origin string `json:"origin"` // the id of the client who sent the message
|
||||
TopicName string `json:"topic_name"` // the topic the message was sent to (if retained)
|
||||
FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message
|
||||
Created int64 `json:"created"` // the time the message was created in unixtime
|
||||
Sent int64 `json:"sent"` // the last time the message was sent (for retries) in unixtime (if inflight)
|
||||
PacketID uint16 `json:"packet_id"` // the unique id of the packet (if inflight)
|
||||
}
|
||||
|
||||
// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages.
|
||||
type MessageProperties struct {
|
||||
CorrelationData []byte `json:"correlationData"`
|
||||
SubscriptionIdentifier []int `json:"subscriptionIdentifier"`
|
||||
User []packets.UserProperty `json:"user"`
|
||||
ContentType string `json:"contentType"`
|
||||
ResponseTopic string `json:"responseTopic"`
|
||||
MessageExpiryInterval uint32 `json:"messageExpiry"`
|
||||
TopicAlias uint16 `json:"topicAlias"`
|
||||
PayloadFormat byte `json:"payloadFormat"`
|
||||
PayloadFormatFlag bool `json:"payloadFormatFlag"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Message) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Message) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// Subscription is a storable representation of an mqtt subscription.
|
||||
type Subscription struct {
|
||||
T string `json:"t"`
|
||||
ID string `json:"id" storm:"id"`
|
||||
Client string `json:"client"`
|
||||
Filter string `json:"filter"`
|
||||
Identifier int `json:"identifier"`
|
||||
RetainHandling byte `json:"retain_handling"`
|
||||
Qos byte `json:"qos"`
|
||||
RetainAsPublished bool `json:"retain_as_pub"`
|
||||
NoLocal bool `json:"no_local"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Subscription) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Subscription) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// SystemInfo is a storable representation of the system information values.
|
||||
type SystemInfo struct {
|
||||
system.Info // embed the system info struct
|
||||
T string `json:"t"` // the data type
|
||||
ID string `json:"id" storm:"id"` // the storage key
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d SystemInfo) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *SystemInfo) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
|
@ -1,196 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
clientStruct = Client{
|
||||
ID: "test",
|
||||
T: "client",
|
||||
Remote: "remote",
|
||||
Listener: "listener",
|
||||
Username: []byte("mochi"),
|
||||
Clean: true,
|
||||
Properties: ClientProperties{
|
||||
SessionExpiryInterval: 2,
|
||||
SessionExpiryIntervalFlag: true,
|
||||
AuthenticationMethod: "a",
|
||||
AuthenticationData: []byte("test"),
|
||||
RequestProblemInfo: 1,
|
||||
RequestProblemInfoFlag: true,
|
||||
RequestResponseInfo: 1,
|
||||
ReceiveMaximum: 128,
|
||||
TopicAliasMaximum: 256,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k", Val: "v"},
|
||||
},
|
||||
MaximumPacketSize: 120,
|
||||
},
|
||||
Will: ClientWill{
|
||||
Qos: 1,
|
||||
Payload: []byte("abc"),
|
||||
TopicName: "a/b/c",
|
||||
Flag: 1,
|
||||
Retain: true,
|
||||
WillDelayInterval: 2,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k2", Val: "v2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)
|
||||
|
||||
messageStruct = Message{
|
||||
T: "message",
|
||||
Payload: []byte("payload"),
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Remaining: 2,
|
||||
Type: 3,
|
||||
Qos: 1,
|
||||
Dup: true,
|
||||
Retain: true,
|
||||
},
|
||||
ID: "id",
|
||||
Origin: "mochi",
|
||||
TopicName: "topic",
|
||||
Properties: MessageProperties{
|
||||
PayloadFormat: 1,
|
||||
PayloadFormatFlag: true,
|
||||
MessageExpiryInterval: 20,
|
||||
ContentType: "type",
|
||||
ResponseTopic: "a/b/r",
|
||||
CorrelationData: []byte("r"),
|
||||
SubscriptionIdentifier: []int{1},
|
||||
TopicAlias: 2,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k2", Val: "v2"},
|
||||
},
|
||||
},
|
||||
Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
|
||||
Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
|
||||
PacketID: 100,
|
||||
}
|
||||
messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)
|
||||
|
||||
subscriptionStruct = Subscription{
|
||||
T: "subscription",
|
||||
ID: "id",
|
||||
Client: "mochi",
|
||||
Filter: "a/b/c",
|
||||
Qos: 1,
|
||||
}
|
||||
subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","identifier":0,"retain_handling":0,"qos":1,"retain_as_pub":false,"no_local":false}`)
|
||||
|
||||
sysInfoStruct = SystemInfo{
|
||||
T: "info",
|
||||
ID: "id",
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
Started: 1,
|
||||
Uptime: 2,
|
||||
BytesReceived: 3,
|
||||
BytesSent: 4,
|
||||
ClientsConnected: 5,
|
||||
ClientsMaximum: 7,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
PacketsReceived: 12,
|
||||
PacketsSent: 13,
|
||||
Retained: 15,
|
||||
Inflight: 16,
|
||||
InflightDropped: 17,
|
||||
},
|
||||
}
|
||||
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
|
||||
)
|
||||
|
||||
func TestClientMarshalBinary(t *testing.T) {
|
||||
data, err := clientStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientJSON, data)
|
||||
}
|
||||
|
||||
func TestClientUnmarshalBinary(t *testing.T) {
|
||||
d := clientStruct
|
||||
err := d.UnmarshalBinary(clientJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientStruct, d)
|
||||
}
|
||||
|
||||
func TestClientUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Client{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Client{}, d)
|
||||
}
|
||||
|
||||
func TestMessageMarshalBinary(t *testing.T) {
|
||||
data, err := messageStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, messageJSON, data)
|
||||
}
|
||||
|
||||
func TestMessageUnmarshalBinary(t *testing.T) {
|
||||
d := messageStruct
|
||||
err := d.UnmarshalBinary(messageJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, messageStruct, d)
|
||||
}
|
||||
|
||||
func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Message{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Message{}, d)
|
||||
}
|
||||
|
||||
func TestSubscriptionMarshalBinary(t *testing.T) {
|
||||
data, err := subscriptionStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, subscriptionJSON, data)
|
||||
}
|
||||
|
||||
func TestSubscriptionUnmarshalBinary(t *testing.T) {
|
||||
d := subscriptionStruct
|
||||
err := d.UnmarshalBinary(subscriptionJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, subscriptionStruct, d)
|
||||
}
|
||||
|
||||
func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Subscription{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Subscription{}, d)
|
||||
}
|
||||
|
||||
func TestSysInfoMarshalBinary(t *testing.T) {
|
||||
data, err := sysInfoStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sysInfoJSON, data)
|
||||
}
|
||||
|
||||
func TestSysInfoUnmarshalBinary(t *testing.T) {
|
||||
d := sysInfoStruct
|
||||
err := d.UnmarshalBinary(sysInfoJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sysInfoStruct, d)
|
||||
}
|
||||
|
||||
func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := SystemInfo{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SystemInfo{}, d)
|
||||
}
|
||||
|
|
@ -1,634 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type modifiedHookBase struct {
|
||||
HookBase
|
||||
err error
|
||||
fail bool
|
||||
failAt int
|
||||
}
|
||||
|
||||
var errTestHook = errors.New("error")
|
||||
|
||||
func (h *modifiedHookBase) ID() string {
|
||||
return "modified"
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Init(config any) error {
|
||||
if config != nil {
|
||||
return errTestHook
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Provides(b byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Stop() error {
|
||||
if h.fail {
|
||||
return errTestHook
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) {
|
||||
if h.fail {
|
||||
return will, errTestHook
|
||||
}
|
||||
|
||||
return will, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) {
|
||||
if h.fail || h.failAt == 1 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Client{
|
||||
{ID: "cl1"},
|
||||
{ID: "cl2"},
|
||||
{ID: "cl3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.fail || h.failAt == 2 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Subscription{
|
||||
{ID: "sub1"},
|
||||
{ID: "sub2"},
|
||||
{ID: "sub3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.fail || h.failAt == 3 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Message{
|
||||
{ID: "r1"},
|
||||
{ID: "r2"},
|
||||
{ID: "r3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.fail || h.failAt == 4 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Message{
|
||||
{ID: "i1"},
|
||||
{ID: "i2"},
|
||||
{ID: "i3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.fail || h.failAt == 5 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return storage.SystemInfo{
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type providesCheckHook struct {
|
||||
HookBase
|
||||
}
|
||||
|
||||
func (h *providesCheckHook) Provides(b byte) bool {
|
||||
return b == OnConnect
|
||||
}
|
||||
|
||||
func TestHooksProvides(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(providesCheckHook), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.Provides(OnConnect, OnDisconnect))
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHooksAddLenGetAll(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(2), h.Len())
|
||||
|
||||
all := h.GetAll()
|
||||
require.Equal(t, "base", all[0].ID())
|
||||
require.Equal(t, "modified", all[1].ID())
|
||||
}
|
||||
|
||||
func TestHooksAddInitFailure(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), map[string]any{})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&h.qty))
|
||||
}
|
||||
|
||||
func TestHooksStop(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(1), h.Len())
|
||||
|
||||
h.Stop()
|
||||
}
|
||||
|
||||
// coverage: also cover some empty functions
|
||||
func TestHooksNonReturns(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
cl := new(Client)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
t.Run("step-"+strconv.Itoa(i), func(t *testing.T) {
|
||||
// on first iteration, check without hook methods
|
||||
h.OnStarted()
|
||||
h.OnStopped()
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
h.OnConnect(cl, packets.Packet{})
|
||||
h.OnSessionEstablished(cl, packets.Packet{})
|
||||
h.OnDisconnect(cl, nil, false)
|
||||
h.OnPacketSent(cl, packets.Packet{}, []byte{})
|
||||
h.OnPacketProcessed(cl, packets.Packet{}, nil)
|
||||
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
|
||||
h.OnUnsubscribed(cl, packets.Packet{})
|
||||
h.OnPublished(cl, packets.Packet{})
|
||||
h.OnPublishDropped(cl, packets.Packet{})
|
||||
h.OnRetainMessage(cl, packets.Packet{}, 0)
|
||||
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
|
||||
h.OnQosComplete(cl, packets.Packet{})
|
||||
h.OnQosDropped(cl, packets.Packet{})
|
||||
h.OnWillSent(cl, packets.Packet{})
|
||||
h.OnClientExpired(cl)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
|
||||
// on second iteration, check added hook methods
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHooksOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
|
||||
ok := h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.False(t, ok)
|
||||
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestHooksOnACLCheck(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
|
||||
ok := h.OnACLCheck(new(Client), "a/b/c", true)
|
||||
require.False(t, ok)
|
||||
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = h.OnACLCheck(new(Client), "a/b/c", true)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestHooksOnSubscribe(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pki := packets.Packet{
|
||||
Filters: packets.Subscriptions{
|
||||
{Filter: "a/b/c", Qos: 1},
|
||||
},
|
||||
}
|
||||
pk := h.OnSubscribe(new(Client), pki)
|
||||
require.EqualValues(t, pk, pki)
|
||||
}
|
||||
|
||||
func TestHooksOnSelectSubscribers(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
subs := &Subscribers{
|
||||
Subscriptions: map[string]packets.Subscription{
|
||||
"cl1": {Filter: "a/b/c"},
|
||||
},
|
||||
}
|
||||
|
||||
subs2 := h.OnSelectSubscribers(subs, packets.Packet{})
|
||||
require.EqualValues(t, subs, subs2)
|
||||
}
|
||||
|
||||
func TestHooksOnUnsubscribe(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pki := packets.Packet{
|
||||
Filters: packets.Subscriptions{
|
||||
{Filter: "a/b/c", Qos: 1},
|
||||
},
|
||||
}
|
||||
|
||||
pk := h.OnUnsubscribe(new(Client), pki)
|
||||
require.EqualValues(t, pk, pki)
|
||||
}
|
||||
|
||||
func TestHooksOnPublish(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: failure
|
||||
hook.fail = true
|
||||
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: reject packet
|
||||
hook.err = packets.ErrRejectPacket
|
||||
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrRejectPacket)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnPacketRead(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: failure
|
||||
hook.fail = true
|
||||
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: reject packet
|
||||
hook.err = packets.ErrRejectPacket
|
||||
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrRejectPacket)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnAuthPacket(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
hook.fail = true
|
||||
pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnPacketEncode(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnLWT(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
|
||||
// coverage: fail lwt
|
||||
hook.fail = true
|
||||
lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
}
|
||||
|
||||
func TestHooksStoredClients(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredClients()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredSubscriptions()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredRetainedMessages()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredInflightMessages()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredSysInfo(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", v.Info.Version)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", v.Info.Version)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredSysInfo()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "", v.Info.Version)
|
||||
}
|
||||
|
||||
func TestHookBaseID(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Equal(t, "base", h.ID())
|
||||
}
|
||||
|
||||
func TestHookBaseProvidesNone(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.False(t, h.Provides(OnConnect))
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHookBaseInit(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Nil(t, h.Init(nil))
|
||||
}
|
||||
|
||||
func TestHookBaseSetOpts(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
h.SetOpts(&logger, new(HookOptions))
|
||||
require.NotNil(t, h.Log)
|
||||
require.NotNil(t, h.Opts)
|
||||
}
|
||||
|
||||
func TestHookBaseClose(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Nil(t, h.Stop())
|
||||
}
|
||||
|
||||
func TestHookBaseOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.False(t, v)
|
||||
}
|
||||
func TestHookBaseOnACLCheck(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v := h.OnACLCheck(new(Client), "topic", true)
|
||||
require.False(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseOnPublish(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnPacketRead(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnAuthPacket(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnLWT(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredClients(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredSubscriptions(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredInflightMessages(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredRetainedMessages(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoreSysInfo(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", v.Version)
|
||||
}
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 42 KiB |
|
|
@ -1,156 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
// Inflight is a map of InflightMessage keyed on packet id.
|
||||
type Inflight struct {
|
||||
sync.RWMutex
|
||||
internal map[uint16]packets.Packet // internal contains the inflight packets
|
||||
receiveQuota int32 // remaining inbound qos quota for flow control
|
||||
sendQuota int32 // remaining outbound qos quota for flow control
|
||||
maximumReceiveQuota int32 // maximum allowed receive quota
|
||||
maximumSendQuota int32 // maximum allowed send quota
|
||||
}
|
||||
|
||||
// NewInflights returns a new instance of an Inflight packets map.
|
||||
func NewInflights() *Inflight {
|
||||
return &Inflight{
|
||||
internal: map[uint16]packets.Packet{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds or updates an inflight packet by packet id.
|
||||
func (i *Inflight) Set(m packets.Packet) bool {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
_, ok := i.internal[m.PacketID]
|
||||
i.internal[m.PacketID] = m
|
||||
return !ok
|
||||
}
|
||||
|
||||
// Get returns an inflight packet by packet id.
|
||||
func (i *Inflight) Get(id uint16) (packets.Packet, bool) {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
if m, ok := i.internal[id]; ok {
|
||||
return m, true
|
||||
}
|
||||
|
||||
return packets.Packet{}, false
|
||||
}
|
||||
|
||||
// Len returns the size of the inflight messages map.
|
||||
func (i *Inflight) Len() int {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
return len(i.internal)
|
||||
}
|
||||
|
||||
// Clone returns a new instance of Inflight with the same message data.
|
||||
// This is used when transferring inflights from a taken-over session.
|
||||
func (i *Inflight) Clone() *Inflight {
|
||||
c := NewInflights()
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
for k, v := range i.internal {
|
||||
c.internal[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// GetAll returns all the inflight messages.
|
||||
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
m := []packets.Packet{}
|
||||
for _, v := range i.internal {
|
||||
if !immediate || (immediate && v.Expiry < 0) {
|
||||
m = append(m, v)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(m, func(i, j int) bool {
|
||||
return uint16(m[i].Created) < uint16(m[j].Created)
|
||||
})
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// NextImmediate returns the next inflight packet which is indicated to be sent immediately.
|
||||
// This typically occurs when the quota has been exhausted, and we need to wait until new quota
|
||||
// is free to continue sending.
|
||||
func (i *Inflight) NextImmediate() (packets.Packet, bool) {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
m := i.GetAll(true)
|
||||
if len(m) > 0 {
|
||||
return m[0], true
|
||||
}
|
||||
|
||||
return packets.Packet{}, false
|
||||
}
|
||||
|
||||
// Delete removes an in-flight message from the map. Returns true if the message existed.
|
||||
func (i *Inflight) Delete(id uint16) bool {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
_, ok := i.internal[id]
|
||||
delete(i.internal, id)
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// TakeRecieveQuota reduces the receive quota by 1.
|
||||
func (i *Inflight) DecreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) > 0 {
|
||||
atomic.AddInt32(&i.receiveQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// TakeRecieveQuota increases the receive quota by 1.
|
||||
func (i *Inflight) IncreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
|
||||
atomic.AddInt32(&i.receiveQuota, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetReceiveQuota resets the receive quota to the maximum allowed value.
|
||||
func (i *Inflight) ResetReceiveQuota(n int32) {
|
||||
atomic.StoreInt32(&i.receiveQuota, n)
|
||||
atomic.StoreInt32(&i.maximumReceiveQuota, n)
|
||||
}
|
||||
|
||||
// DecreaseSendQuota reduces the send quota by 1.
|
||||
func (i *Inflight) DecreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) > 0 {
|
||||
atomic.AddInt32(&i.sendQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// IncreaseSendQuota increases the send quota by 1.
|
||||
func (i *Inflight) IncreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
|
||||
atomic.AddInt32(&i.sendQuota, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetSendQuota resets the send quota to the maximum allowed value.
|
||||
func (i *Inflight) ResetSendQuota(n int32) {
|
||||
atomic.StoreInt32(&i.sendQuota, n)
|
||||
atomic.StoreInt32(&i.maximumSendQuota, n)
|
||||
}
|
||||
|
|
@ -1,199 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInflightSet(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.True(t, r)
|
||||
require.NotNil(t, cl.State.Inflight.internal[1])
|
||||
require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID)
|
||||
|
||||
r = cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.False(t, r)
|
||||
}
|
||||
|
||||
func TestInflightGet(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
msg, ok := cl.State.Inflight.Get(2)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, 0, msg.PacketID)
|
||||
}
|
||||
|
||||
func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
|
||||
|
||||
require.Equal(t, []packets.Packet{
|
||||
{PacketID: 1, Created: 1},
|
||||
{PacketID: 2, Created: 2},
|
||||
{PacketID: 3, Created: 3, Expiry: -1},
|
||||
{PacketID: 4, Created: 4, Expiry: -1},
|
||||
{PacketID: 5, Created: 5},
|
||||
}, cl.State.Inflight.GetAll(false))
|
||||
|
||||
require.Equal(t, []packets.Packet{
|
||||
{PacketID: 3, Created: 3, Expiry: -1},
|
||||
{PacketID: 4, Created: 4, Expiry: -1},
|
||||
}, cl.State.Inflight.GetAll(true))
|
||||
}
|
||||
|
||||
func TestInflightLen(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestInflightClone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
cloned := cl.State.Inflight.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.NotSame(t, cloned, cl.State.Inflight)
|
||||
}
|
||||
|
||||
func TestInflightDelete(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
|
||||
require.NotNil(t, cl.State.Inflight.internal[3])
|
||||
|
||||
r := cl.State.Inflight.Delete(3)
|
||||
require.True(t, r)
|
||||
require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID)
|
||||
|
||||
_, ok := cl.State.Inflight.Get(3)
|
||||
require.False(t, ok)
|
||||
|
||||
r = cl.State.Inflight.Delete(3)
|
||||
require.False(t, r)
|
||||
}
|
||||
|
||||
func TestResetReceiveQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
i.ResetReceiveQuota(6)
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota))
|
||||
}
|
||||
|
||||
func TestReceiveQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
i.receiveQuota = 4
|
||||
i.maximumReceiveQuota = 5
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Return 1
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Reset to max 1
|
||||
i.ResetReceiveQuota(1)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Take 1
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
}
|
||||
|
||||
func TestResetSendQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
i.ResetSendQuota(6)
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota))
|
||||
}
|
||||
|
||||
func TestSendQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
i.sendQuota = 4
|
||||
i.maximumSendQuota = 5
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Return 1
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Reset to max 1
|
||||
i.ResetSendQuota(1)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Take 1
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
}
|
||||
|
||||
func TestNextImmediate(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
|
||||
|
||||
pk, ok := cl.State.Inflight.NextImmediate()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk)
|
||||
|
||||
r := cl.State.Inflight.Delete(3)
|
||||
require.True(t, r)
|
||||
|
||||
pk, ok = cl.State.Inflight.NextImmediate()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk)
|
||||
|
||||
r = cl.State.Inflight.Delete(4)
|
||||
require.True(t, r)
|
||||
|
||||
_, ok = cl.State.Inflight.NextImmediate()
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
|
||||
type HTTPStats struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
log *zerolog.Logger // server logger
|
||||
sysInfo *system.Info // pointers to the server data
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
|
||||
func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats {
|
||||
if config == nil {
|
||||
config = new(Config)
|
||||
}
|
||||
return &HTTPStats{
|
||||
id: id,
|
||||
address: address,
|
||||
sysInfo: sysInfo,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *HTTPStats) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *HTTPStats) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *HTTPStats) Protocol() string {
|
||||
if l.listen != nil && l.listen.TLSConfig != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
return "http"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *HTTPStats) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.jsonHandler)
|
||||
l.listen = &http.Server{
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
if l.config.TLSConfig != nil {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPStats) Serve(establish EstablishFn) {
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *HTTPStats) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
|
||||
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
|
||||
info := *l.sysInfo.Clone()
|
||||
|
||||
out, err := json.MarshalIndent(info, "", "\t")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
}
|
||||
|
||||
w.Write(out)
|
||||
}
|
||||
|
|
@ -1,127 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHTTPStats(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, nil, nil)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestHTTPStatsID(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, nil, nil)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestHTTPStatsAddress(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, nil, nil)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestHTTPStatsProtocol(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, nil, nil)
|
||||
require.Equal(t, "http", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPStatsTLSProtocol(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
}, nil)
|
||||
|
||||
l.Init(nil)
|
||||
require.Equal(t, "https", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPStatsInit(t *testing.T) {
|
||||
sysInfo := new(system.Info)
|
||||
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.sysInfo)
|
||||
require.Equal(t, sysInfo, l.sysInfo)
|
||||
require.NotNil(t, l.listen)
|
||||
require.Equal(t, testAddr, l.listen.Addr)
|
||||
}
|
||||
|
||||
func TestHTTPStatsServeAndClose(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
// setup http stats listener
|
||||
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// get body from stats address
|
||||
resp, err := http.Get("http://localhost" + testAddr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// decode body from json and check data
|
||||
v := new(system.Info)
|
||||
err = json.Unmarshal(body, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", v.Version)
|
||||
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
|
||||
_, err = http.Get("http://localhost" + testAddr)
|
||||
require.Error(t, err)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
l := NewHTTPStats("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
}, sysInfo)
|
||||
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
l.Close(MockCloser)
|
||||
}
|
||||
|
|
@ -1,135 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
type Config struct {
|
||||
// TLSConfig is a tls.Config configuration to be used with the listener.
|
||||
// See examples folder for basic and mutual-tls use.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// EstablishFn is a callback function for establishing new clients.
|
||||
type EstablishFn func(id string, c net.Conn) error
|
||||
|
||||
// CloseFunc is a callback function for closing all listener clients.
|
||||
type CloseFn func(id string)
|
||||
|
||||
// Listener is an interface for network listeners. A network listener listens
|
||||
// for incoming client connections and adds them to the server.
|
||||
type Listener interface {
|
||||
Init(*zerolog.Logger) error // open the network address
|
||||
Serve(EstablishFn) // starting actively listening for new connections
|
||||
ID() string // return the id of the listener
|
||||
Address() string // the address of the listener
|
||||
Protocol() string // the protocol in use by the listener
|
||||
Close(CloseFn) // stop and close the listener
|
||||
}
|
||||
|
||||
// Listeners contains the network listeners for the broker.
|
||||
type Listeners struct {
|
||||
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// New returns a new instance of Listeners.
|
||||
func New() *Listeners {
|
||||
return &Listeners{
|
||||
internal: map[string]Listener{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new listener to the listeners map, keyed on id.
|
||||
func (l *Listeners) Add(val Listener) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.internal[val.ID()] = val
|
||||
}
|
||||
|
||||
// Get returns the value of a listener if it exists.
|
||||
func (l *Listeners) Get(id string) (Listener, bool) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
val, ok := l.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the length of the listeners map.
|
||||
func (l *Listeners) Len() int {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
return len(l.internal)
|
||||
}
|
||||
|
||||
// Delete removes a listener from the internal map.
|
||||
func (l *Listeners) Delete(id string) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
delete(l.internal, id)
|
||||
}
|
||||
|
||||
// Serve starts a listener serving from the internal map.
|
||||
func (l *Listeners) Serve(id string, establisher EstablishFn) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
listener := l.internal[id]
|
||||
|
||||
go func(e EstablishFn) {
|
||||
defer l.wg.Done()
|
||||
l.wg.Add(1)
|
||||
listener.Serve(e)
|
||||
}(establisher)
|
||||
}
|
||||
|
||||
// ServeAll starts all listeners serving from the internal map.
|
||||
func (l *Listeners) ServeAll(establisher EstablishFn) {
|
||||
l.RLock()
|
||||
i := 0
|
||||
ids := make([]string, len(l.internal))
|
||||
for id := range l.internal {
|
||||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
l.RUnlock()
|
||||
|
||||
for _, id := range ids {
|
||||
l.Serve(id, establisher)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops a listener from the internal map.
|
||||
func (l *Listeners) Close(id string, closer CloseFn) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
if listener, ok := l.internal[id]; ok {
|
||||
listener.Close(closer)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseAll iterates and closes all registered listeners.
|
||||
func (l *Listeners) CloseAll(closer CloseFn) {
|
||||
l.RLock()
|
||||
i := 0
|
||||
ids := make([]string, len(l.internal))
|
||||
for id := range l.internal {
|
||||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
l.RUnlock()
|
||||
|
||||
for _, id := range ids {
|
||||
l.Close(id, closer)
|
||||
}
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testAddr = ":22222"
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
|
||||
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
|
||||
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
|
||||
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
|
||||
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
|
||||
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
|
||||
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
|
||||
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
|
||||
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
|
||||
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
|
||||
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
|
||||
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
|
||||
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
|
||||
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
|
||||
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
|
||||
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
|
||||
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
|
||||
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
|
||||
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
|
||||
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
tlsConfigBasic *tls.Config
|
||||
)
|
||||
|
||||
func init() {
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfigBasic = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
l := New()
|
||||
require.NotNil(t, l.internal)
|
||||
}
|
||||
|
||||
func TestAddListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
}
|
||||
|
||||
func TestGetListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
require.Contains(t, l.internal, "t2")
|
||||
|
||||
g, ok := l.Get("t1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, g.ID(), "t1")
|
||||
}
|
||||
|
||||
func TestLenListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
require.Contains(t, l.internal, "t2")
|
||||
require.Equal(t, 2, l.Len())
|
||||
}
|
||||
|
||||
func TestDeleteListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
l.Delete("t1")
|
||||
_, ok := l.Get("t1")
|
||||
require.False(t, ok)
|
||||
require.Nil(t, l.internal["t1"])
|
||||
}
|
||||
|
||||
func TestServeListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Serve("t1", MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
|
||||
l.Close("t1", MockCloser)
|
||||
require.False(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
}
|
||||
|
||||
func TestServeAllListeners(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
l.Add(NewMockListener("t3", testAddr))
|
||||
l.ServeAll(MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
|
||||
l.Close("t1", MockCloser)
|
||||
l.Close("t2", MockCloser)
|
||||
l.Close("t3", MockCloser)
|
||||
|
||||
require.False(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.False(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.False(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
}
|
||||
|
||||
func TestCloseListener(t *testing.T) {
|
||||
l := New()
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
l.Add(mocked)
|
||||
l.Serve("t1", MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close("t1", func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.True(t, closed)
|
||||
}
|
||||
|
||||
func TestCloseAllListeners(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
l.Add(NewMockListener("t3", testAddr))
|
||||
l.ServeAll(MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
|
||||
closed := make(map[string]bool)
|
||||
l.CloseAll(func(id string) {
|
||||
closed[id] = true
|
||||
})
|
||||
require.Contains(t, closed, "t1")
|
||||
require.Contains(t, closed, "t2")
|
||||
require.Contains(t, closed, "t3")
|
||||
require.True(t, closed["t1"])
|
||||
require.True(t, closed["t2"])
|
||||
require.True(t, closed["t3"])
|
||||
}
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// MockEstablisher is a function signature which can be used in testing.
|
||||
func MockEstablisher(id string, c net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockCloser is a function signature which can be used in testing.
|
||||
func MockCloser(id string) {}
|
||||
|
||||
// MockListener is a mock listener for establishing client connections.
|
||||
type MockListener struct {
|
||||
sync.RWMutex
|
||||
id string // the id of the listener
|
||||
address string // the network address the listener binds to
|
||||
Config *Config // configuration for the listener
|
||||
done chan bool // indicate the listener is done
|
||||
Serving bool // indicate the listener is serving
|
||||
Listening bool // indiciate the listener is listening
|
||||
ErrListen bool // throw an error on listen
|
||||
}
|
||||
|
||||
// NewMockListener returns a new instance of MockListener.
|
||||
func NewMockListener(id, address string) *MockListener {
|
||||
return &MockListener{
|
||||
id: id,
|
||||
address: address,
|
||||
done: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Serve serves the mock listener.
|
||||
func (l *MockListener) Serve(establisher EstablishFn) {
|
||||
l.Lock()
|
||||
l.Serving = true
|
||||
l.Unlock()
|
||||
|
||||
for range l.done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *MockListener) Init(log *zerolog.Logger) error {
|
||||
if l.ErrListen {
|
||||
return fmt.Errorf("listen failure")
|
||||
}
|
||||
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Listening = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// ID returns the id of the mock listener.
|
||||
func (l *MockListener) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *MockListener) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *MockListener) Protocol() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
// Close closes the mock listener.
|
||||
func (l *MockListener) Close(closer CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Serving = false
|
||||
closer(l.id)
|
||||
close(l.done)
|
||||
}
|
||||
|
||||
// IsServing indicates whether the mock listener is serving.
|
||||
func (l *MockListener) IsServing() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Serving
|
||||
}
|
||||
|
||||
// IsListening indicates whether the mock listener is listening.
|
||||
func (l *MockListener) IsListening() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Listening
|
||||
}
|
||||
|
|
@ -1,99 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMockEstablisher(t *testing.T) {
|
||||
_, w := net.Pipe()
|
||||
err := MockEstablisher("t1", w)
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
}
|
||||
|
||||
func TestNewMockListener(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.id)
|
||||
require.Equal(t, testAddr, mocked.address)
|
||||
}
|
||||
func TestMockListenerID(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.ID())
|
||||
}
|
||||
|
||||
func TestMockListenerAddress(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, testAddr, mocked.Address())
|
||||
}
|
||||
func TestMockListenerProtocol(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "mock", mocked.Protocol())
|
||||
}
|
||||
|
||||
func TestNewMockListenerIsListening(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsListening())
|
||||
}
|
||||
|
||||
func TestNewMockListenerIsServing(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsServing())
|
||||
}
|
||||
|
||||
func TestNewMockListenerInit(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.id)
|
||||
require.Equal(t, testAddr, mocked.address)
|
||||
|
||||
require.Equal(t, false, mocked.IsListening())
|
||||
err := mocked.Init(nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, mocked.IsListening())
|
||||
}
|
||||
|
||||
func TestNewMockListenerInitFailure(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
mocked.ErrListen = true
|
||||
err := mocked.Init(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMockListenerServe(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsServing())
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
mocked.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
|
||||
require.Equal(t, true, mocked.IsServing())
|
||||
|
||||
var closed bool
|
||||
mocked.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
|
||||
mocked.Init(nil)
|
||||
}
|
||||
|
||||
func TestMockListenerClose(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
var closed bool
|
||||
mocked.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Net is a listener for establishing client connections on basic TCP protocol.
|
||||
type Net struct { // [MQTT-4.2.0-1]
|
||||
mu sync.Mutex
|
||||
listener net.Listener // a net.Listener which will listen for new clients
|
||||
id string // the internal id of the listener
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
|
||||
func NewNet(id string, listener net.Listener) *Net {
|
||||
return &Net{
|
||||
id: id,
|
||||
listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Net) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *Net) Address() string {
|
||||
return l.listener.Addr().String()
|
||||
}
|
||||
|
||||
// Protocol returns the network of the listener.
|
||||
func (l *Net) Protocol() string {
|
||||
return l.listener.Addr().Network()
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Net) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *Net) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Net) Close(closeClients CloseFn) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listener != nil {
|
||||
err := l.listener.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,105 +0,0 @@
|
|||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewNet(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.id)
|
||||
}
|
||||
|
||||
func TestNetID(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestNetAddress(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, n.Addr().String(), l.Address())
|
||||
}
|
||||
|
||||
func TestNetProtocol(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestNetInit(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNetServeAndClose(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestNetEstablishThenEnd(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("tcp", n.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// TCP is a listener for establishing client connections on basic TCP protocol.
|
||||
type TCP struct { // [MQTT-4.2.0-1]
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
listen net.Listener // a net.Listener which will listen for new clients
|
||||
config *Config // configuration values for the listener
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewTCP initialises and returns a new TCP listener, listening on an address.
|
||||
func NewTCP(id, address string, config *Config) *TCP {
|
||||
if config == nil {
|
||||
config = new(Config)
|
||||
}
|
||||
|
||||
return &TCP{
|
||||
id: id,
|
||||
address: address,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *TCP) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *TCP) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *TCP) Protocol() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *TCP) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
if l.config.TLSConfig != nil {
|
||||
l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig)
|
||||
} else {
|
||||
l.listen, err = net.Listen("tcp", l.address)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *TCP) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *TCP) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,131 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTCP(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestTCPID(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestTCPAddress(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestTCPProtocol(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestTCPProtocolTLS(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
|
||||
l.Init(&logger)
|
||||
defer l.listen.Close()
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestTCPInit(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewTCP("t2", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err = l2.Init(&logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l2.config.TLSConfig)
|
||||
}
|
||||
|
||||
func TestTCPServeAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestTCPServeTLSAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestTCPEstablishThenEnd(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("tcp", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
|
|
@ -1,98 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
|
||||
type UnixSock struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
|
||||
func NewUnixSock(id, address string) *UnixSock {
|
||||
return &UnixSock{
|
||||
id: id,
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *UnixSock) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *UnixSock) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *UnixSock) Protocol() string {
|
||||
return "unix"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *UnixSock) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
_ = os.Remove(l.address)
|
||||
l.listen, err = net.Listen("unix", l.address)
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new UnixSock connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *UnixSock) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *UnixSock) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testUnixAddr = "mochi.sock"
|
||||
|
||||
func TestNewUnixSock(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testUnixAddr, l.address)
|
||||
}
|
||||
|
||||
func TestUnixSockID(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestUnixSockAddress(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, testUnixAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestUnixSockProtocol(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "unix", l.Protocol())
|
||||
}
|
||||
|
||||
func TestUnixSockInit(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewUnixSock("t2", testUnixAddr)
|
||||
err = l2.Init(&logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnixSockServeAndClose(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("unix", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
|
|
@ -1,178 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidMessage indicates that a message payload was not valid.
|
||||
ErrInvalidMessage = errors.New("message type not binary")
|
||||
)
|
||||
|
||||
// Websocket is a listener for establishing websocket connections.
|
||||
type Websocket struct { // [MQTT-4.2.0-1]
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // an http server for serving websocket connections
|
||||
log *zerolog.Logger // server logger
|
||||
establish EstablishFn // the server's establish connection handler
|
||||
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
|
||||
func NewWebsocket(id, address string, config *Config) *Websocket {
|
||||
if config == nil {
|
||||
config = new(Config)
|
||||
}
|
||||
|
||||
return &Websocket{
|
||||
id: id,
|
||||
address: address,
|
||||
config: config,
|
||||
upgrader: &websocket.Upgrader{
|
||||
Subprotocols: []string{"mqtt"},
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Websocket) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *Websocket) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *Websocket) Protocol() string {
|
||||
if l.config.TLSConfig != nil {
|
||||
return "wss"
|
||||
}
|
||||
|
||||
return "ws"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Websocket) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.handler)
|
||||
l.listen = &http.Server{
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
TLSConfig: l.config.TLSConfig,
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handler upgrades and handles an incoming websocket connection.
|
||||
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := l.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
err = l.establish(l.id, &wsConn{c.UnderlyingConn(), c})
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
// Serve starts waiting for new Websocket connections, and calls the connection
|
||||
// establishment callback for any received.
|
||||
func (l *Websocket) Serve(establish EstablishFn) {
|
||||
l.establish = establish
|
||||
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Websocket) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
// wsConn is a websocket connection which satisfies the net.Conn interface.
|
||||
type wsConn struct {
|
||||
net.Conn
|
||||
c *websocket.Conn
|
||||
}
|
||||
|
||||
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
|
||||
func (ws *wsConn) Read(p []byte) (int, error) {
|
||||
op, r, err := ws.c.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if op != websocket.BinaryMessage {
|
||||
err = ErrInvalidMessage
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var n, br int
|
||||
for {
|
||||
br, err = r.Read(p[n:])
|
||||
n += br
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes bytes to the websocket connection.
|
||||
func (ws *wsConn) Write(p []byte) (int, error) {
|
||||
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Close signals the underlying websocket conn to close.
|
||||
func (ws *wsConn) Close() error {
|
||||
return ws.Conn.Close()
|
||||
}
|
||||
|
|
@ -1,114 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewWebsocket(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestWebsocketID(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestWebsocketAddress(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestWebsocketProtocol(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Equal(t, "ws", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsocketProtocoTLS(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
require.Equal(t, "wss", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsockeInit(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Nil(t, l.listen)
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketServeAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestWebsocketServeTLSAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func TestWebsocketUpgrade(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
|
||||
e := make(chan bool)
|
||||
l.establish = func(id string, c net.Conn) error {
|
||||
e <- true
|
||||
return nil
|
||||
}
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(l.handler))
|
||||
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, <-e)
|
||||
|
||||
s.Close()
|
||||
ws.Close()
|
||||
}
|
||||
|
|
@ -1,172 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// bytesToString provides a zero-alloc no-copy byte to string conversion.
|
||||
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
|
||||
func bytesToString(bs []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&bs))
|
||||
}
|
||||
|
||||
// decodeUint16 extracts the value of two bytes from a byte array.
|
||||
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
|
||||
if len(buf) < offset+2 {
|
||||
return 0, 0, ErrMalformedOffsetUintOutOfRange
|
||||
}
|
||||
|
||||
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
|
||||
}
|
||||
|
||||
// decodeUint32 extracts the value of four bytes from a byte array.
|
||||
func decodeUint32(buf []byte, offset int) (uint32, int, error) {
|
||||
if len(buf) < offset+4 {
|
||||
return 0, 0, ErrMalformedOffsetUintOutOfRange
|
||||
}
|
||||
|
||||
return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil
|
||||
}
|
||||
|
||||
// decodeString extracts a string from a byte array, beginning at an offset.
|
||||
func decodeString(buf []byte, offset int) (string, int, error) {
|
||||
b, n, err := decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5]
|
||||
return "", 0, ErrMalformedInvalidUTF8
|
||||
}
|
||||
|
||||
return bytesToString(b), n, nil
|
||||
}
|
||||
|
||||
// validUTF8 checks if the byte array contains valid UTF-8 characters.
|
||||
func validUTF8(b []byte) bool {
|
||||
return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2]
|
||||
}
|
||||
|
||||
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
|
||||
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
|
||||
length, next, err := decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return make([]byte, 0), 0, err
|
||||
}
|
||||
|
||||
if next+int(length) > len(buf) {
|
||||
return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange
|
||||
}
|
||||
|
||||
return buf[next : next+int(length)], next + int(length), nil
|
||||
}
|
||||
|
||||
// decodeByte extracts the value of a byte from a byte array.
|
||||
func decodeByte(buf []byte, offset int) (byte, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return 0, 0, ErrMalformedOffsetByteOutOfRange
|
||||
}
|
||||
return buf[offset], offset + 1, nil
|
||||
}
|
||||
|
||||
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
|
||||
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return false, 0, ErrMalformedOffsetBoolOutOfRange
|
||||
}
|
||||
return 1&buf[offset] > 0, offset + 1, nil
|
||||
}
|
||||
|
||||
// encodeBool returns a byte instead of a bool.
|
||||
func encodeBool(b bool) byte {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
|
||||
func encodeBytes(val []byte) []byte {
|
||||
// In most circumstances the number of bytes being encoded is small.
|
||||
// Setting the cap to a low amount allows us to account for those without
|
||||
// triggering allocation growth on append unless we need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, val...)
|
||||
}
|
||||
|
||||
// encodeUint16 encodes a uint16 value to a byte array.
|
||||
func encodeUint16(val uint16) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, val)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeUint32 encodes a uint16 value to a byte array.
|
||||
func encodeUint32(val uint32) []byte {
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, val)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeString encodes a string to a byte array.
|
||||
func encodeString(val string) []byte {
|
||||
// Like encodeBytes, we set the cap to a small number to avoid
|
||||
// triggering allocation growth on append unless we absolutely need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, []byte(val)...)
|
||||
}
|
||||
|
||||
// encodeLength writes length bits for the header.
|
||||
func encodeLength(b *bytes.Buffer, length int64) {
|
||||
// 1.5.5 Variable Byte Integer encode non-normative
|
||||
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
|
||||
for {
|
||||
eb := byte(length % 128)
|
||||
length /= 128
|
||||
if length > 0 {
|
||||
eb |= 0x80
|
||||
}
|
||||
b.WriteByte(eb)
|
||||
if length == 0 {
|
||||
break // [MQTT-1.5.5-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeLength(b io.ByteReader) (n, bu int, err error) {
|
||||
// see 1.5.5 Variable Byte Integer decode non-normative
|
||||
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
|
||||
var multiplier uint32
|
||||
var value uint32
|
||||
bu = 1
|
||||
for {
|
||||
eb, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, bu, err
|
||||
}
|
||||
|
||||
value |= uint32(eb&127) << multiplier
|
||||
if value > 268435455 {
|
||||
return 0, bu, ErrMalformedVariableByteInteger
|
||||
}
|
||||
|
||||
if (eb & 128) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
multiplier += 7
|
||||
bu++
|
||||
}
|
||||
|
||||
return int(value), bu, nil
|
||||
}
|
||||
|
|
@ -1,422 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBytesToString(t *testing.T) {
|
||||
b := []byte{'a', 'b', 'c'}
|
||||
require.Equal(t, "abc", bytesToString(b))
|
||||
}
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
rawBytes []byte
|
||||
result string
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: "a/b/c/d",
|
||||
},
|
||||
{
|
||||
offset: 14,
|
||||
rawBytes: []byte{
|
||||
Connect << 4, 17, // Fixed header
|
||||
0, 6, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
|
||||
3, // Protocol Version
|
||||
0, // Packet Flags
|
||||
0, 30, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'h', 'e', 'y', // Client ID "zen"},
|
||||
},
|
||||
result: "hey",
|
||||
},
|
||||
{
|
||||
offset: 2,
|
||||
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
|
||||
result: "1/2/3/4/a/b/c/d/e/^/@/!",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
|
||||
result: "x/y/z",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 5,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 9,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 17,
|
||||
rawBytes: []byte{
|
||||
Connect << 4, 0, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
0, // Flags
|
||||
0, 20, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 6, // Will Topic - MSB+LSB
|
||||
'l',
|
||||
},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrMalformedInvalidUTF8,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3]
|
||||
result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "\ufeff", result)
|
||||
}
|
||||
|
||||
func TestDecodeBytes(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result []uint8
|
||||
next int
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session)
|
||||
result: []byte{0x4d, 0x51, 0x54, 0x54},
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start
|
||||
result: []byte{0x4d, 0x51, 0x54, 0x54},
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 0,
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByte(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint8
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
|
||||
result: uint8(0x00),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x04),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x4d),
|
||||
offset: 2,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x51),
|
||||
offset: 3,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 80, 82, 84},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetByteOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUint16(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint16
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x07),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x761),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+2, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUint32(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint32
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 0, 0, 7, 8},
|
||||
result: uint32(7),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 0, 1, 226, 64, 8},
|
||||
result: uint32(123456),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+4, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByteBool(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result bool
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0x00, 0x00},
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
offset: 5,
|
||||
shouldFail: ErrMalformedOffsetBoolOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, 1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeLength(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{0x78})
|
||||
n, bu, err := DecodeLength(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 120, n)
|
||||
require.Equal(t, 1, bu)
|
||||
|
||||
b = bytes.NewBuffer([]byte{255, 255, 255, 127})
|
||||
n, bu, err = DecodeLength(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 268435455, n)
|
||||
require.Equal(t, 4, bu)
|
||||
}
|
||||
|
||||
func TestDecodeLengthErrors(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
_, _, err := DecodeLength(b)
|
||||
require.Error(t, err)
|
||||
|
||||
b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f})
|
||||
_, _, err = DecodeLength(b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
|
||||
}
|
||||
|
||||
func TestEncodeBool(t *testing.T) {
|
||||
result := encodeBool(true)
|
||||
require.Equal(t, byte(1), result)
|
||||
|
||||
result = encodeBool(false)
|
||||
require.Equal(t, byte(0), result)
|
||||
|
||||
// Check failure.
|
||||
result = encodeBool(false)
|
||||
require.NotEqual(t, byte(1), result)
|
||||
}
|
||||
|
||||
func TestEncodeBytes(t *testing.T) {
|
||||
result := encodeBytes([]byte("testing"))
|
||||
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result)
|
||||
|
||||
result = encodeBytes([]byte("testing"))
|
||||
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result)
|
||||
}
|
||||
|
||||
func TestEncodeUint16(t *testing.T) {
|
||||
result := encodeUint16(0)
|
||||
require.Equal(t, []byte{0x00, 0x00}, result)
|
||||
|
||||
result = encodeUint16(32767)
|
||||
require.Equal(t, []byte{0x7f, 0xff}, result)
|
||||
|
||||
result = encodeUint16(math.MaxUint16)
|
||||
require.Equal(t, []byte{0xff, 0xff}, result)
|
||||
}
|
||||
|
||||
func TestEncodeUint32(t *testing.T) {
|
||||
result := encodeUint32(7)
|
||||
require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result)
|
||||
|
||||
result = encodeUint32(32767)
|
||||
require.Equal(t, []byte{0, 0, 127, 255}, result)
|
||||
|
||||
result = encodeUint32(math.MaxUint32)
|
||||
require.Equal(t, []byte{255, 255, 255, 255}, result)
|
||||
}
|
||||
|
||||
func TestEncodeString(t *testing.T) {
|
||||
result := encodeString("testing")
|
||||
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result)
|
||||
|
||||
result = encodeString("")
|
||||
require.Equal(t, []uint8{0x00, 0x00}, result)
|
||||
|
||||
result = encodeString("a")
|
||||
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result)
|
||||
|
||||
result = encodeString("b")
|
||||
require.NotEqual(t, []uint8{0x00, 0x00}, result)
|
||||
}
|
||||
|
||||
func TestEncodeLength(t *testing.T) {
|
||||
b := new(bytes.Buffer)
|
||||
encodeLength(b, 120)
|
||||
require.Equal(t, []byte{0x78}, b.Bytes())
|
||||
|
||||
b = new(bytes.Buffer)
|
||||
encodeLength(b, math.MaxInt64)
|
||||
require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestValidUTF8(t *testing.T) {
|
||||
require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}))
|
||||
require.False(t, validUTF8([]byte{0xff, 0xff}))
|
||||
require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74}))
|
||||
}
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
// Code contains a reason code and reason string for a response.
|
||||
type Code struct {
|
||||
Reason string
|
||||
Code byte
|
||||
}
|
||||
|
||||
// String returns the readable reason for a code.
|
||||
func (c Code) String() string {
|
||||
return c.Reason
|
||||
}
|
||||
|
||||
// Error returns the readable reason for a code.
|
||||
func (c Code) Error() string {
|
||||
return c.Reason
|
||||
}
|
||||
|
||||
var (
|
||||
// QosCodes indicicates the reason codes for each Qos byte.
|
||||
QosCodes = map[byte]Code{
|
||||
0: CodeGrantedQos0,
|
||||
1: CodeGrantedQos1,
|
||||
2: CodeGrantedQos2,
|
||||
}
|
||||
|
||||
CodeSuccess = Code{Code: 0x00, Reason: "success"}
|
||||
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
|
||||
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
|
||||
CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"}
|
||||
CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"}
|
||||
CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"}
|
||||
CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"}
|
||||
CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"}
|
||||
CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"}
|
||||
CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"}
|
||||
ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"}
|
||||
ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"}
|
||||
ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"}
|
||||
ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"}
|
||||
ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"}
|
||||
ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"}
|
||||
ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"}
|
||||
ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"}
|
||||
ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"}
|
||||
ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"}
|
||||
ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"}
|
||||
ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"}
|
||||
ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"}
|
||||
ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"}
|
||||
ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"}
|
||||
ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"}
|
||||
ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"}
|
||||
ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"}
|
||||
ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"}
|
||||
ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"}
|
||||
ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"}
|
||||
ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"}
|
||||
ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"}
|
||||
ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"}
|
||||
ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"}
|
||||
ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"}
|
||||
ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"}
|
||||
ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"}
|
||||
ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"}
|
||||
ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"}
|
||||
ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
|
||||
ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
|
||||
ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"}
|
||||
ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"}
|
||||
ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"}
|
||||
ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"}
|
||||
ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"}
|
||||
ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"}
|
||||
ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"}
|
||||
ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"}
|
||||
ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"}
|
||||
ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"}
|
||||
ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"}
|
||||
ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"}
|
||||
ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"}
|
||||
ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"}
|
||||
ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"}
|
||||
ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"}
|
||||
ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"}
|
||||
ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"}
|
||||
ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"}
|
||||
ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"}
|
||||
ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"}
|
||||
ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"}
|
||||
ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"}
|
||||
ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"}
|
||||
ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"}
|
||||
ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"}
|
||||
ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"}
|
||||
ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"}
|
||||
ErrServerBusy = Code{Code: 0x89, Reason: "server busy"}
|
||||
ErrBanned = Code{Code: 0x8A, Reason: "banned"}
|
||||
ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"}
|
||||
ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"}
|
||||
ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"}
|
||||
ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"}
|
||||
ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"}
|
||||
ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"}
|
||||
ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"}
|
||||
ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"}
|
||||
ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"}
|
||||
ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"}
|
||||
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
|
||||
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
|
||||
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
|
||||
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
|
||||
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
|
||||
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
|
||||
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
|
||||
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
|
||||
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
|
||||
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
|
||||
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
|
||||
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
|
||||
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
||||
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
||||
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
||||
|
||||
// MQTTv3 specific bytes.
|
||||
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
||||
Err3ClientIdentifierNotValid = Code{Code: 0x02}
|
||||
Err3ServerUnavailable = Code{Code: 0x03}
|
||||
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
|
||||
Err3NotAuthorized = Code{Code: 0x05}
|
||||
|
||||
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
|
||||
// This is required because MQTTv3 has different return byte specification.
|
||||
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
|
||||
V5CodesToV3 = map[Code]Code{
|
||||
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
|
||||
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
|
||||
ErrServerUnavailable: Err3ServerUnavailable,
|
||||
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
|
||||
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
|
||||
ErrBadUsernameOrPassword: Err3NotAuthorized,
|
||||
}
|
||||
)
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCodesString(t *testing.T) {
|
||||
c := Code{
|
||||
Reason: "test",
|
||||
Code: 0x1,
|
||||
}
|
||||
|
||||
require.Equal(t, "test", c.String())
|
||||
}
|
||||
|
||||
func TestCodesErrorr(t *testing.T) {
|
||||
c := Code{
|
||||
Reason: "error",
|
||||
Code: 0x1,
|
||||
}
|
||||
|
||||
require.Equal(t, "error", error(c).Error())
|
||||
}
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
|
||||
type FixedHeader struct {
|
||||
Remaining int `json:"remaining"` // the number of remaining bytes in the payload.
|
||||
Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Qos byte `json:"qos"` // indicates the quality of service expected.
|
||||
Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time.
|
||||
Retain bool `json:"retain"` // whether the message should be retained.
|
||||
}
|
||||
|
||||
// Encode encodes the FixedHeader and returns a bytes buffer.
|
||||
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
|
||||
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
|
||||
encodeLength(buf, int64(fh.Remaining))
|
||||
}
|
||||
|
||||
// Decode extracts the specification bits from the header byte.
|
||||
func (fh *FixedHeader) Decode(hb byte) error {
|
||||
fh.Type = hb >> 4 // Get the message type from the first 4 bytes.
|
||||
|
||||
switch fh.Type {
|
||||
case Publish:
|
||||
if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 {
|
||||
return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4]
|
||||
}
|
||||
|
||||
fh.Dup = (hb>>3)&0x01 > 0 // is duplicate
|
||||
fh.Qos = (hb >> 1) & 0x03 // qos flag
|
||||
fh.Retain = hb&0x01 > 0 // is retain flag
|
||||
case Pubrel:
|
||||
fallthrough
|
||||
case Subscribe:
|
||||
fallthrough
|
||||
case Unsubscribe:
|
||||
if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1]
|
||||
return ErrMalformedFlags
|
||||
}
|
||||
|
||||
fh.Qos = (hb >> 1) & 0x03
|
||||
default:
|
||||
if (hb>>0)&0x01 != 0 ||
|
||||
(hb>>1)&0x01 != 0 ||
|
||||
(hb>>2)&0x01 != 0 ||
|
||||
(hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1]
|
||||
return ErrMalformedFlags
|
||||
}
|
||||
}
|
||||
|
||||
if fh.Qos == 0 && fh.Dup {
|
||||
return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,237 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fixedHeaderTable struct {
|
||||
desc string
|
||||
rawBytes []byte
|
||||
header FixedHeader
|
||||
packetError bool
|
||||
expect error
|
||||
}
|
||||
|
||||
var fixedHeaderExpected = []fixedHeaderTable{
|
||||
{
|
||||
desc: "connect",
|
||||
rawBytes: []byte{Connect << 4, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "connack",
|
||||
rawBytes: []byte{Connack << 4, 0x00},
|
||||
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish",
|
||||
rawBytes: []byte{Publish << 4, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 1",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 1 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 2",
|
||||
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 2 retain",
|
||||
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 0",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
expect: ErrProtocolViolationDupNoQos,
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 0 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
|
||||
expect: ErrProtocolViolationDupNoQos,
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 1 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 2 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "puback",
|
||||
rawBytes: []byte{Puback << 4, 0x00},
|
||||
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubrec",
|
||||
rawBytes: []byte{Pubrec << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubrel",
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubcomp",
|
||||
rawBytes: []byte{Pubcomp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "subscribe",
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "suback",
|
||||
rawBytes: []byte{Suback << 4, 0x00},
|
||||
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "unsubscribe",
|
||||
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "unsuback",
|
||||
rawBytes: []byte{Unsuback << 4, 0x00},
|
||||
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pingreq",
|
||||
rawBytes: []byte{Pingreq << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pingresp",
|
||||
rawBytes: []byte{Pingresp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "disconnect",
|
||||
rawBytes: []byte{Disconnect << 4, 0x00},
|
||||
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "auth",
|
||||
rawBytes: []byte{Auth << 4, 0x00},
|
||||
header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
|
||||
// remaining length
|
||||
{
|
||||
desc: "remaining length 10",
|
||||
rawBytes: []byte{Publish << 4, 0x0a},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 512",
|
||||
rawBytes: []byte{Publish << 4, 0x80, 0x04},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 978",
|
||||
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 20202",
|
||||
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
|
||||
},
|
||||
{
|
||||
desc: "remaining length oversize",
|
||||
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
|
||||
packetError: true,
|
||||
},
|
||||
|
||||
// Invalid flags for packet
|
||||
{
|
||||
desc: "invalid type dup is true",
|
||||
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid type qos is 1",
|
||||
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid type retain is true",
|
||||
rawBytes: []byte{Connect<<4 | 1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid publish qos bits 1 + 2 set",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00},
|
||||
header: FixedHeader{Type: Publish},
|
||||
expect: ErrProtocolViolationQosOutOfRange,
|
||||
},
|
||||
{
|
||||
desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0",
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00},
|
||||
header: FixedHeader{Type: Pubrel, Qos: 1},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0",
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00},
|
||||
header: FixedHeader{Type: Subscribe, Qos: 1},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
}
|
||||
|
||||
func TestFixedHeaderEncode(t *testing.T) {
|
||||
for _, wanted := range fixedHeaderExpected {
|
||||
t.Run(wanted.desc, func(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
wanted.header.Encode(buf)
|
||||
if wanted.expect == nil {
|
||||
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()))
|
||||
require.EqualValues(t, wanted.rawBytes, buf.Bytes())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixedHeaderDecode(t *testing.T) {
|
||||
for _, wanted := range fixedHeaderExpected {
|
||||
t.Run(wanted.desc, func(t *testing.T) {
|
||||
fh := new(FixedHeader)
|
||||
err := fh.Decode(wanted.rawBytes[0])
|
||||
if wanted.expect != nil {
|
||||
require.Equal(t, wanted.expect, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.header.Type, fh.Type)
|
||||
require.Equal(t, wanted.header.Dup, fh.Dup)
|
||||
require.Equal(t, wanted.header.Qos, fh.Qos)
|
||||
require.Equal(t, wanted.header.Retain, fh.Retain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,505 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const pkInfo = "packet type %v, %s"
|
||||
|
||||
var packetList = []byte{
|
||||
Connect,
|
||||
Connack,
|
||||
Publish,
|
||||
Puback,
|
||||
Pubrec,
|
||||
Pubrel,
|
||||
Pubcomp,
|
||||
Subscribe,
|
||||
Suback,
|
||||
Unsubscribe,
|
||||
Unsuback,
|
||||
Pingreq,
|
||||
Pingresp,
|
||||
Disconnect,
|
||||
Auth,
|
||||
}
|
||||
|
||||
var pkTable = []TPacketCase{
|
||||
TPacketData[Connect].Get(TConnectMqtt311),
|
||||
TPacketData[Connect].Get(TConnectMqtt5),
|
||||
TPacketData[Connect].Get(TConnectUserPassLWT),
|
||||
TPacketData[Connack].Get(TConnackAcceptedMqtt5),
|
||||
TPacketData[Connack].Get(TConnackAcceptedNoSession),
|
||||
TPacketData[Publish].Get(TPublishBasic),
|
||||
TPacketData[Publish].Get(TPublishMqtt5),
|
||||
TPacketData[Puback].Get(TPuback),
|
||||
TPacketData[Pubrec].Get(TPubrec),
|
||||
TPacketData[Pubrel].Get(TPubrel),
|
||||
TPacketData[Pubcomp].Get(TPubcomp),
|
||||
TPacketData[Subscribe].Get(TSubscribe),
|
||||
TPacketData[Subscribe].Get(TSubscribeMqtt5),
|
||||
TPacketData[Suback].Get(TSuback),
|
||||
TPacketData[Unsubscribe].Get(TUnsubscribe),
|
||||
TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
|
||||
TPacketData[Pingreq].Get(TPingreq),
|
||||
TPacketData[Pingresp].Get(TPingresp),
|
||||
TPacketData[Disconnect].Get(TDisconnect),
|
||||
TPacketData[Disconnect].Get(TDisconnectMqtt5),
|
||||
}
|
||||
|
||||
func TestNewPackets(t *testing.T) {
|
||||
s := NewPackets()
|
||||
require.NotNil(t, s.internal)
|
||||
}
|
||||
|
||||
func TestPacketsAdd(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
}
|
||||
|
||||
func TestPacketsGet(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
|
||||
pk, ok := s.Get("cl1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "a1", pk.TopicName)
|
||||
}
|
||||
|
||||
func TestPacketsGetAll(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
s.Add("cl3", Packet{TopicName: "a3"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Contains(t, s.internal, "cl3")
|
||||
|
||||
subs := s.GetAll()
|
||||
require.Len(t, subs, 3)
|
||||
}
|
||||
|
||||
func TestPacketsLen(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Equal(t, 2, s.Len())
|
||||
}
|
||||
|
||||
func TestSPacketsDelete(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
|
||||
s.Delete("cl1")
|
||||
_, ok := s.Get("cl1")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestFormatPacketID(t *testing.T) {
|
||||
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
|
||||
packet := &Packet{PacketID: id}
|
||||
require.Equal(t, fmt.Sprint(id), packet.FormatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
|
||||
p := &Subscription{
|
||||
Qos: 2,
|
||||
NoLocal: true,
|
||||
RetainAsPublished: true,
|
||||
RetainHandling: 2,
|
||||
}
|
||||
x := new(Subscription)
|
||||
x.decode(p.encode())
|
||||
require.Equal(t, *p, *x)
|
||||
|
||||
p = &Subscription{
|
||||
Qos: 1,
|
||||
NoLocal: false,
|
||||
RetainAsPublished: false,
|
||||
RetainHandling: 1,
|
||||
}
|
||||
x = new(Subscription)
|
||||
x.decode(p.encode())
|
||||
require.Equal(t, *p, *x)
|
||||
}
|
||||
|
||||
func TestPacketEncode(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if !encodeTestOK(wanted) {
|
||||
return
|
||||
}
|
||||
|
||||
pk := new(Packet)
|
||||
copier.Copy(pk, wanted.Packet)
|
||||
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectEncode(buf)
|
||||
case Connack:
|
||||
err = pk.ConnackEncode(buf)
|
||||
case Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
case Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
case Pubrec:
|
||||
err = pk.PubrecEncode(buf)
|
||||
case Pubrel:
|
||||
err = pk.PubrelEncode(buf)
|
||||
case Pubcomp:
|
||||
err = pk.PubcompEncode(buf)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeEncode(buf)
|
||||
case Suback:
|
||||
err = pk.SubackEncode(buf)
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeEncode(buf)
|
||||
case Unsuback:
|
||||
err = pk.UnsubackEncode(buf)
|
||||
case Pingreq:
|
||||
err = pk.PingreqEncode(buf)
|
||||
case Pingresp:
|
||||
err = pk.PingrespEncode(buf)
|
||||
case Disconnect:
|
||||
err = pk.DisconnectEncode(buf)
|
||||
case Auth:
|
||||
err = pk.AuthEncode(buf)
|
||||
}
|
||||
if wanted.Expect != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
|
||||
encoded := buf.Bytes()
|
||||
|
||||
// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
|
||||
if len(wanted.ActualBytes) > 0 {
|
||||
wanted.RawBytes = wanted.ActualBytes
|
||||
}
|
||||
require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketDecode(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if !decodeTestOK(wanted) {
|
||||
return
|
||||
}
|
||||
|
||||
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
pk.FixedHeader.Decode(wanted.RawBytes[0])
|
||||
if len(wanted.RawBytes) > 0 {
|
||||
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
|
||||
}
|
||||
|
||||
if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
|
||||
pk.ProtocolVersion = wanted.Packet.ProtocolVersion
|
||||
}
|
||||
|
||||
buf := wanted.RawBytes[2:]
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectDecode(buf)
|
||||
case Connack:
|
||||
err = pk.ConnackDecode(buf)
|
||||
case Publish:
|
||||
err = pk.PublishDecode(buf)
|
||||
case Puback:
|
||||
err = pk.PubackDecode(buf)
|
||||
case Pubrec:
|
||||
err = pk.PubrecDecode(buf)
|
||||
case Pubrel:
|
||||
err = pk.PubrelDecode(buf)
|
||||
case Pubcomp:
|
||||
err = pk.PubcompDecode(buf)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeDecode(buf)
|
||||
case Suback:
|
||||
err = pk.SubackDecode(buf)
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeDecode(buf)
|
||||
case Unsuback:
|
||||
err = pk.UnsubackDecode(buf)
|
||||
case Pingreq:
|
||||
err = pk.PingreqDecode(buf)
|
||||
case Pingresp:
|
||||
err = pk.PingrespDecode(buf)
|
||||
case Disconnect:
|
||||
err = pk.DisconnectDecode(buf)
|
||||
case Auth:
|
||||
err = pk.AuthDecode(buf)
|
||||
}
|
||||
|
||||
if wanted.FailFirst != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
if pkt == Connect {
|
||||
// we use ProtocolVersion for controlling packet encoding, but we don't need to test
|
||||
// against it unless it's a connect packet.
|
||||
require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
|
||||
}
|
||||
require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
|
||||
require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if wanted.Group == "validate" || wanted.Primary {
|
||||
pk := wanted.Packet
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectValidate()
|
||||
case Publish:
|
||||
err = pk.PublishValidate(1024)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeValidate()
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeValidate()
|
||||
case Auth:
|
||||
err = pk.AuthValidate()
|
||||
}
|
||||
|
||||
if wanted.Expect != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAckValidatePubrec(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
CodeNoMatchingSubscribers.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicNameInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
ErrQuotaExceeded.Code,
|
||||
ErrPayloadFormatInvalid.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidatePubrel(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
ErrPacketIdentifierNotFound.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidatePubcomp(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
ErrPacketIdentifierNotFound.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidateSuback(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeGrantedQos0.Code,
|
||||
CodeGrantedQos1.Code,
|
||||
CodeGrantedQos2.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicFilterInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
ErrQuotaExceeded.Code,
|
||||
ErrSharedSubscriptionsNotSupported.Code,
|
||||
ErrSubscriptionIdentifiersNotSupported.Code,
|
||||
ErrWildcardSubscriptionsNotSupported.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidateUnsuback(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
CodeNoSubscriptionExisted.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicFilterInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestReasonCodeValidMisc(t *testing.T) {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestCopy(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
pkc := tt.Packet.Copy(true)
|
||||
|
||||
require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
|
||||
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
|
||||
|
||||
pkcc := tt.Packet.Copy(false)
|
||||
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSubscription(t *testing.T) {
|
||||
sub := Subscription{
|
||||
Filter: "a/b/c",
|
||||
RetainHandling: 0,
|
||||
Qos: 0,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: false,
|
||||
Identifier: 1,
|
||||
}
|
||||
|
||||
sub2 := Subscription{
|
||||
Filter: "a/b/d",
|
||||
RetainHandling: 0,
|
||||
Qos: 2,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: true,
|
||||
Identifier: 2,
|
||||
}
|
||||
|
||||
expect := Subscription{
|
||||
Filter: "a/b/c",
|
||||
RetainHandling: 0,
|
||||
Qos: 2,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: true,
|
||||
Identifier: 1,
|
||||
Identifiers: map[string]int{
|
||||
"a/b/c": 1,
|
||||
"a/b/d": 2,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expect, sub.Merge(sub2))
|
||||
}
|
||||
|
|
@ -1,479 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
PropPayloadFormat byte = 1
|
||||
PropMessageExpiryInterval byte = 2
|
||||
PropContentType byte = 3
|
||||
PropResponseTopic byte = 8
|
||||
PropCorrelationData byte = 9
|
||||
PropSubscriptionIdentifier byte = 11
|
||||
PropSessionExpiryInterval byte = 17
|
||||
PropAssignedClientID byte = 18
|
||||
PropServerKeepAlive byte = 19
|
||||
PropAuthenticationMethod byte = 21
|
||||
PropAuthenticationData byte = 22
|
||||
PropRequestProblemInfo byte = 23
|
||||
PropWillDelayInterval byte = 24
|
||||
PropRequestResponseInfo byte = 25
|
||||
PropResponseInfo byte = 26
|
||||
PropServerReference byte = 28
|
||||
PropReasonString byte = 31
|
||||
PropReceiveMaximum byte = 33
|
||||
PropTopicAliasMaximum byte = 34
|
||||
PropTopicAlias byte = 35
|
||||
PropMaximumQos byte = 36
|
||||
PropRetainAvailable byte = 37
|
||||
PropUser byte = 38
|
||||
PropMaximumPacketSize byte = 39
|
||||
PropWildcardSubAvailable byte = 40
|
||||
PropSubIDAvailable byte = 41
|
||||
PropSharedSubAvailable byte = 42
|
||||
)
|
||||
|
||||
// validPacketProperties indicates which properties are valid for which packet types.
|
||||
var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropPayloadFormat: {Publish: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1},
|
||||
PropContentType: {Publish: 1},
|
||||
PropResponseTopic: {Publish: 1},
|
||||
PropCorrelationData: {Publish: 1},
|
||||
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
|
||||
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
|
||||
PropAssignedClientID: {Connack: 1},
|
||||
PropServerKeepAlive: {Connack: 1},
|
||||
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropRequestProblemInfo: {Connect: 1},
|
||||
PropWillDelayInterval: {Connect: 1},
|
||||
PropRequestResponseInfo: {Connect: 1},
|
||||
PropResponseInfo: {Connack: 1},
|
||||
PropServerReference: {Connack: 1, Disconnect: 1},
|
||||
PropReasonString: {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
|
||||
PropReceiveMaximum: {Connect: 1, Connack: 1},
|
||||
PropTopicAliasMaximum: {Connect: 1, Connack: 1},
|
||||
PropTopicAlias: {Publish: 1},
|
||||
PropMaximumQos: {Connack: 1},
|
||||
PropRetainAvailable: {Connack: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
|
||||
PropMaximumPacketSize: {Connect: 1, Connack: 1},
|
||||
PropWildcardSubAvailable: {Connack: 1},
|
||||
PropSubIDAvailable: {Connack: 1},
|
||||
PropSharedSubAvailable: {Connack: 1},
|
||||
}
|
||||
|
||||
// UserProperty is an arbitrary key-value pair for a packet user properties array.
|
||||
type UserProperty struct { // [MQTT-1.5.7-1]
|
||||
Key string `json:"k"`
|
||||
Val string `json:"v"`
|
||||
}
|
||||
|
||||
// Properties contains all of the mqtt v5 properties available for a packet.
|
||||
// Some properties have valid values of 0 or not-present. In this case, we opt for
|
||||
// property flags to indicate the usage of property.
|
||||
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
|
||||
type Properties struct {
|
||||
CorrelationData []byte `json:"cd"`
|
||||
SubscriptionIdentifier []int `json:"si"`
|
||||
AuthenticationData []byte `json:"ad"`
|
||||
User []UserProperty `json:"user"`
|
||||
ContentType string `json:"ct"`
|
||||
ResponseTopic string `json:"rt"`
|
||||
AssignedClientID string `json:"aci"`
|
||||
AuthenticationMethod string `json:"am"`
|
||||
ResponseInfo string `json:"ri"`
|
||||
ServerReference string `json:"sr"`
|
||||
ReasonString string `json:"rs"`
|
||||
MessageExpiryInterval uint32 `json:"me"`
|
||||
SessionExpiryInterval uint32 `json:"sei"`
|
||||
WillDelayInterval uint32 `json:"wdi"`
|
||||
MaximumPacketSize uint32 `json:"mps"`
|
||||
ServerKeepAlive uint16 `json:"ska"`
|
||||
ReceiveMaximum uint16 `json:"rm"`
|
||||
TopicAliasMaximum uint16 `json:"tam"`
|
||||
TopicAlias uint16 `json:"ta"`
|
||||
PayloadFormat byte `json:"pf"`
|
||||
PayloadFormatFlag bool `json:"fpf"`
|
||||
SessionExpiryIntervalFlag bool `json:"fsei"`
|
||||
ServerKeepAliveFlag bool `json:"fska"`
|
||||
RequestProblemInfo byte `json:"rpi"`
|
||||
RequestProblemInfoFlag bool `json:"frpi"`
|
||||
RequestResponseInfo byte `json:"rri"`
|
||||
TopicAliasFlag bool `json:"fta"`
|
||||
MaximumQos byte `json:"mqos"`
|
||||
MaximumQosFlag bool `json:"fmqos"`
|
||||
RetainAvailable byte `json:"ra"`
|
||||
RetainAvailableFlag bool `json:"fra"`
|
||||
WildcardSubAvailable byte `json:"wsa"`
|
||||
WildcardSubAvailableFlag bool `json:"fwsa"`
|
||||
SubIDAvailable byte `json:"sida"`
|
||||
SubIDAvailableFlag bool `json:"fsida"`
|
||||
SharedSubAvailable byte `json:"ssa"`
|
||||
SharedSubAvailableFlag bool `json:"fssa"`
|
||||
}
|
||||
|
||||
// Copy creates a new Properties struct with copies of the values.
|
||||
func (p *Properties) Copy(allowTransfer bool) Properties {
|
||||
pr := Properties{
|
||||
PayloadFormat: p.PayloadFormat, // [MQTT-3.3.2-4]
|
||||
PayloadFormatFlag: p.PayloadFormatFlag,
|
||||
MessageExpiryInterval: p.MessageExpiryInterval,
|
||||
ContentType: p.ContentType, // [MQTT-3.3.2-20]
|
||||
ResponseTopic: p.ResponseTopic, // [MQTT-3.3.2-15]
|
||||
SessionExpiryInterval: p.SessionExpiryInterval,
|
||||
SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag,
|
||||
AssignedClientID: p.AssignedClientID,
|
||||
ServerKeepAlive: p.ServerKeepAlive,
|
||||
ServerKeepAliveFlag: p.ServerKeepAliveFlag,
|
||||
AuthenticationMethod: p.AuthenticationMethod,
|
||||
RequestProblemInfo: p.RequestProblemInfo,
|
||||
RequestProblemInfoFlag: p.RequestProblemInfoFlag,
|
||||
WillDelayInterval: p.WillDelayInterval,
|
||||
RequestResponseInfo: p.RequestResponseInfo,
|
||||
ResponseInfo: p.ResponseInfo,
|
||||
ServerReference: p.ServerReference,
|
||||
ReasonString: p.ReasonString,
|
||||
ReceiveMaximum: p.ReceiveMaximum,
|
||||
TopicAliasMaximum: p.TopicAliasMaximum,
|
||||
TopicAlias: 0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27]
|
||||
MaximumQos: p.MaximumQos,
|
||||
MaximumQosFlag: p.MaximumQosFlag,
|
||||
RetainAvailable: p.RetainAvailable,
|
||||
RetainAvailableFlag: p.RetainAvailableFlag,
|
||||
MaximumPacketSize: p.MaximumPacketSize,
|
||||
WildcardSubAvailable: p.WildcardSubAvailable,
|
||||
WildcardSubAvailableFlag: p.WildcardSubAvailableFlag,
|
||||
SubIDAvailable: p.SubIDAvailable,
|
||||
SubIDAvailableFlag: p.SubIDAvailableFlag,
|
||||
SharedSubAvailable: p.SharedSubAvailable,
|
||||
SharedSubAvailableFlag: p.SharedSubAvailableFlag,
|
||||
}
|
||||
|
||||
if allowTransfer {
|
||||
pr.TopicAlias = p.TopicAlias
|
||||
pr.TopicAliasFlag = p.TopicAliasFlag
|
||||
}
|
||||
|
||||
if len(p.CorrelationData) > 0 {
|
||||
pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16]
|
||||
}
|
||||
|
||||
if len(p.SubscriptionIdentifier) > 0 {
|
||||
pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...)
|
||||
}
|
||||
|
||||
if len(p.AuthenticationData) > 0 {
|
||||
pr.AuthenticationData = append([]byte{}, p.AuthenticationData...)
|
||||
}
|
||||
|
||||
if len(p.User) > 0 {
|
||||
pr.User = []UserProperty{}
|
||||
for _, v := range p.User {
|
||||
pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17]
|
||||
Key: v.Key,
|
||||
Val: v.Val,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return pr
|
||||
}
|
||||
|
||||
// canEncode returns true if the property type is valid for the packet type.
|
||||
func (p *Properties) canEncode(pkt byte, k byte) bool {
|
||||
return validPacketProperties[k][pkt] == 1
|
||||
}
|
||||
|
||||
// Encode encodes properties into a bytes buffer.
|
||||
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
pkt := pk.FixedHeader.Type
|
||||
|
||||
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
|
||||
buf.WriteByte(PropPayloadFormat)
|
||||
buf.WriteByte(p.PayloadFormat)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 {
|
||||
buf.WriteByte(PropMessageExpiryInterval)
|
||||
buf.Write(encodeUint32(p.MessageExpiryInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropContentType) && p.ContentType != "" {
|
||||
buf.WriteByte(PropContentType)
|
||||
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseTopic)
|
||||
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropCorrelationData)
|
||||
buf.Write(encodeBytes(p.CorrelationData))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 {
|
||||
for _, v := range p.SubscriptionIdentifier {
|
||||
if v > 0 {
|
||||
buf.WriteByte(PropSubscriptionIdentifier)
|
||||
encodeLength(&buf, int64(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2]
|
||||
buf.WriteByte(PropSessionExpiryInterval)
|
||||
buf.Write(encodeUint32(p.SessionExpiryInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" {
|
||||
buf.WriteByte(PropAssignedClientID)
|
||||
buf.Write(encodeString(p.AssignedClientID))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag {
|
||||
buf.WriteByte(PropServerKeepAlive)
|
||||
buf.Write(encodeUint16(p.ServerKeepAlive))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" {
|
||||
buf.WriteByte(PropAuthenticationMethod)
|
||||
buf.Write(encodeString(p.AuthenticationMethod))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 {
|
||||
buf.WriteByte(PropAuthenticationData)
|
||||
buf.Write(encodeBytes(p.AuthenticationData))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag {
|
||||
buf.WriteByte(PropRequestProblemInfo)
|
||||
buf.WriteByte(p.RequestProblemInfo)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 {
|
||||
buf.WriteByte(PropWillDelayInterval)
|
||||
buf.Write(encodeUint32(p.WillDelayInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 {
|
||||
buf.WriteByte(PropRequestResponseInfo)
|
||||
buf.WriteByte(p.RequestResponseInfo)
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseInfo)
|
||||
buf.Write(encodeString(p.ResponseInfo))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 {
|
||||
buf.WriteByte(PropServerReference)
|
||||
buf.Write(encodeString(p.ServerReference))
|
||||
}
|
||||
|
||||
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
|
||||
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
b := encodeString(p.ReasonString)
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
|
||||
buf.WriteByte(PropReasonString)
|
||||
buf.Write(b)
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 {
|
||||
buf.WriteByte(PropReceiveMaximum)
|
||||
buf.Write(encodeUint16(p.ReceiveMaximum))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 {
|
||||
buf.WriteByte(PropTopicAliasMaximum)
|
||||
buf.Write(encodeUint16(p.TopicAliasMaximum))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8]
|
||||
buf.WriteByte(PropTopicAlias)
|
||||
buf.Write(encodeUint16(p.TopicAlias))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 {
|
||||
buf.WriteByte(PropMaximumQos)
|
||||
buf.WriteByte(p.MaximumQos)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag {
|
||||
buf.WriteByte(PropRetainAvailable)
|
||||
buf.WriteByte(p.RetainAvailable)
|
||||
}
|
||||
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
for _, v := range p.User {
|
||||
pb.WriteByte(PropUser)
|
||||
pb.Write(encodeString(v.Key))
|
||||
pb.Write(encodeString(v.Val))
|
||||
}
|
||||
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
|
||||
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
|
||||
buf.Write(pb.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 {
|
||||
buf.WriteByte(PropMaximumPacketSize)
|
||||
buf.Write(encodeUint32(p.MaximumPacketSize))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag {
|
||||
buf.WriteByte(PropWildcardSubAvailable)
|
||||
buf.WriteByte(p.WildcardSubAvailable)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag {
|
||||
buf.WriteByte(PropSubIDAvailable)
|
||||
buf.WriteByte(p.SubIDAvailable)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag {
|
||||
buf.WriteByte(PropSharedSubAvailable)
|
||||
buf.WriteByte(p.SharedSubAvailable)
|
||||
}
|
||||
|
||||
encodeLength(b, int64(buf.Len()))
|
||||
buf.WriteTo(b) // [MQTT-3.1.3-10]
|
||||
}
|
||||
|
||||
// Decode decodes property bytes into a properties struct.
|
||||
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
if p == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var bu int
|
||||
n, bu, err = DecodeLength(b)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return n + bu, nil
|
||||
}
|
||||
|
||||
bt := b.Bytes()
|
||||
var k byte
|
||||
for offset := 0; offset < n; {
|
||||
k, offset, err = decodeByte(bt, offset)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if _, ok := validPacketProperties[k][pk]; !ok {
|
||||
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
switch k {
|
||||
case PropPayloadFormat:
|
||||
p.PayloadFormat, offset, err = decodeByte(bt, offset)
|
||||
p.PayloadFormatFlag = true
|
||||
case PropMessageExpiryInterval:
|
||||
p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset)
|
||||
case PropContentType:
|
||||
p.ContentType, offset, err = decodeString(bt, offset)
|
||||
case PropResponseTopic:
|
||||
p.ResponseTopic, offset, err = decodeString(bt, offset)
|
||||
case PropCorrelationData:
|
||||
p.CorrelationData, offset, err = decodeBytes(bt, offset)
|
||||
case PropSubscriptionIdentifier:
|
||||
if p.SubscriptionIdentifier == nil {
|
||||
p.SubscriptionIdentifier = []int{}
|
||||
}
|
||||
|
||||
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
|
||||
offset += bu
|
||||
case PropSessionExpiryInterval:
|
||||
p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset)
|
||||
p.SessionExpiryIntervalFlag = true
|
||||
case PropAssignedClientID:
|
||||
p.AssignedClientID, offset, err = decodeString(bt, offset)
|
||||
case PropServerKeepAlive:
|
||||
p.ServerKeepAlive, offset, err = decodeUint16(bt, offset)
|
||||
p.ServerKeepAliveFlag = true
|
||||
case PropAuthenticationMethod:
|
||||
p.AuthenticationMethod, offset, err = decodeString(bt, offset)
|
||||
case PropAuthenticationData:
|
||||
p.AuthenticationData, offset, err = decodeBytes(bt, offset)
|
||||
case PropRequestProblemInfo:
|
||||
p.RequestProblemInfo, offset, err = decodeByte(bt, offset)
|
||||
p.RequestProblemInfoFlag = true
|
||||
case PropWillDelayInterval:
|
||||
p.WillDelayInterval, offset, err = decodeUint32(bt, offset)
|
||||
case PropRequestResponseInfo:
|
||||
p.RequestResponseInfo, offset, err = decodeByte(bt, offset)
|
||||
case PropResponseInfo:
|
||||
p.ResponseInfo, offset, err = decodeString(bt, offset)
|
||||
case PropServerReference:
|
||||
p.ServerReference, offset, err = decodeString(bt, offset)
|
||||
case PropReasonString:
|
||||
p.ReasonString, offset, err = decodeString(bt, offset)
|
||||
case PropReceiveMaximum:
|
||||
p.ReceiveMaximum, offset, err = decodeUint16(bt, offset)
|
||||
case PropTopicAliasMaximum:
|
||||
p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset)
|
||||
case PropTopicAlias:
|
||||
p.TopicAlias, offset, err = decodeUint16(bt, offset)
|
||||
p.TopicAliasFlag = true
|
||||
case PropMaximumQos:
|
||||
p.MaximumQos, offset, err = decodeByte(bt, offset)
|
||||
p.MaximumQosFlag = true
|
||||
case PropRetainAvailable:
|
||||
p.RetainAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.RetainAvailableFlag = true
|
||||
case PropUser:
|
||||
var k, v string
|
||||
k, offset, err = decodeString(bt, offset)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
v, offset, err = decodeString(bt, offset)
|
||||
p.User = append(p.User, UserProperty{Key: k, Val: v})
|
||||
case PropMaximumPacketSize:
|
||||
p.MaximumPacketSize, offset, err = decodeUint32(bt, offset)
|
||||
case PropWildcardSubAvailable:
|
||||
p.WildcardSubAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.WildcardSubAvailableFlag = true
|
||||
case PropSubIDAvailable:
|
||||
p.SubIDAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.SubIDAvailableFlag = true
|
||||
case PropSharedSubAvailable:
|
||||
p.SharedSubAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.SharedSubAvailableFlag = true
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
}
|
||||
|
||||
return n + bu, nil
|
||||
}
|
||||
|
|
@ -1,333 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
propertiesStruct = Properties{
|
||||
PayloadFormat: byte(1), // UTF-8 Format
|
||||
PayloadFormatFlag: true,
|
||||
MessageExpiryInterval: uint32(2),
|
||||
ContentType: "text/plain",
|
||||
ResponseTopic: "a/b/c",
|
||||
CorrelationData: []byte("data"),
|
||||
SubscriptionIdentifier: []int{322122},
|
||||
SessionExpiryInterval: uint32(120),
|
||||
SessionExpiryIntervalFlag: true,
|
||||
AssignedClientID: "mochi-v5",
|
||||
ServerKeepAlive: uint16(20),
|
||||
ServerKeepAliveFlag: true,
|
||||
AuthenticationMethod: "SHA-1",
|
||||
AuthenticationData: []byte("auth-data"),
|
||||
RequestProblemInfo: byte(1),
|
||||
RequestProblemInfoFlag: true,
|
||||
WillDelayInterval: uint32(600),
|
||||
RequestResponseInfo: byte(1),
|
||||
ResponseInfo: "response",
|
||||
ServerReference: "mochi-2",
|
||||
ReasonString: "reason",
|
||||
ReceiveMaximum: uint16(500),
|
||||
TopicAliasMaximum: uint16(999),
|
||||
TopicAlias: uint16(3),
|
||||
TopicAliasFlag: true,
|
||||
MaximumQos: byte(1),
|
||||
MaximumQosFlag: true,
|
||||
RetainAvailable: byte(1),
|
||||
RetainAvailableFlag: true,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
{
|
||||
Key: "key2",
|
||||
Val: "value2",
|
||||
},
|
||||
},
|
||||
MaximumPacketSize: uint32(32000),
|
||||
WildcardSubAvailable: byte(1),
|
||||
WildcardSubAvailableFlag: true,
|
||||
SubIDAvailable: byte(1),
|
||||
SubIDAvailableFlag: true,
|
||||
SharedSubAvailable: byte(1),
|
||||
SharedSubAvailableFlag: true,
|
||||
}
|
||||
|
||||
propertiesBytes = []byte{
|
||||
172, 1, // VBI
|
||||
|
||||
// Payload Format (1) (vbi:2)
|
||||
1, 1,
|
||||
|
||||
// Message Expiry (2) (vbi:7)
|
||||
2, 0, 0, 0, 2,
|
||||
|
||||
// Content Type (3) (vbi:20)
|
||||
3,
|
||||
0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n',
|
||||
|
||||
// Response Topic (8) (vbi:28)
|
||||
8,
|
||||
0, 5, 'a', '/', 'b', '/', 'c',
|
||||
|
||||
// Correlations Data (9) (vbi:35)
|
||||
9,
|
||||
0, 4, 'd', 'a', 't', 'a',
|
||||
|
||||
// Subscription Identifier (11) (vbi:39)
|
||||
11,
|
||||
202, 212, 19,
|
||||
|
||||
// Session Expiry Interval (17) (vbi:43)
|
||||
17,
|
||||
0, 0, 0, 120,
|
||||
|
||||
// Assigned Client ID (18) (vbi:55)
|
||||
18,
|
||||
0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5',
|
||||
|
||||
// Server Keep Alive (19) (vbi:58)
|
||||
19,
|
||||
0, 20,
|
||||
|
||||
// Authentication Method (21) (vbi:66)
|
||||
21,
|
||||
0, 5, 'S', 'H', 'A', '-', '1',
|
||||
|
||||
// Authentication Data (22) (vbi:78)
|
||||
22,
|
||||
0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a',
|
||||
|
||||
// Request Problem Info (23) (vbi:80)
|
||||
23, 1,
|
||||
|
||||
// Will Delay Interval (24) (vbi:85)
|
||||
24,
|
||||
0, 0, 2, 88,
|
||||
|
||||
// Request Response Info (25) (vbi:87)
|
||||
25, 1,
|
||||
|
||||
// Response Info (26) (vbi:98)
|
||||
26,
|
||||
0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e',
|
||||
|
||||
// Server Reference (28) (vbi:108)
|
||||
28,
|
||||
0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2',
|
||||
|
||||
// Reason String (31) (vbi:117)
|
||||
31,
|
||||
0, 6, 'r', 'e', 'a', 's', 'o', 'n',
|
||||
|
||||
// Receive Maximum (33) (vbi:120)
|
||||
33,
|
||||
1, 244,
|
||||
|
||||
// Topic Alias Maximum (34) (vbi:123)
|
||||
34,
|
||||
3, 231,
|
||||
|
||||
// Topic Alias (35) (vbi:126)
|
||||
35,
|
||||
0, 3,
|
||||
|
||||
// Maximum Qos (36) (vbi:128)
|
||||
36, 1,
|
||||
|
||||
// Retain Available (37) (vbi: 130)
|
||||
37, 1,
|
||||
|
||||
// User Properties (38) (vbi:161)
|
||||
38,
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
38,
|
||||
0, 4, 'k', 'e', 'y', '2',
|
||||
0, 6, 'v', 'a', 'l', 'u', 'e', '2',
|
||||
|
||||
// Maximum Packet Size (39) (vbi:166)
|
||||
39,
|
||||
0, 0, 125, 0,
|
||||
|
||||
// Wildcard Subscriptions Available (40) (vbi:168)
|
||||
40, 1,
|
||||
|
||||
// Subscription ID Available (41) (vbi:170)
|
||||
41, 1,
|
||||
|
||||
// Shared Subscriptions Available (42) (vbi:172)
|
||||
42, 1,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
validPacketProperties[PropPayloadFormat][Reserved] = 1
|
||||
validPacketProperties[PropMessageExpiryInterval][Reserved] = 1
|
||||
validPacketProperties[PropContentType][Reserved] = 1
|
||||
validPacketProperties[PropResponseTopic][Reserved] = 1
|
||||
validPacketProperties[PropCorrelationData][Reserved] = 1
|
||||
validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1
|
||||
validPacketProperties[PropSessionExpiryInterval][Reserved] = 1
|
||||
validPacketProperties[PropAssignedClientID][Reserved] = 1
|
||||
validPacketProperties[PropServerKeepAlive][Reserved] = 1
|
||||
validPacketProperties[PropAuthenticationMethod][Reserved] = 1
|
||||
validPacketProperties[PropAuthenticationData][Reserved] = 1
|
||||
validPacketProperties[PropRequestProblemInfo][Reserved] = 1
|
||||
validPacketProperties[PropWillDelayInterval][Reserved] = 1
|
||||
validPacketProperties[PropRequestResponseInfo][Reserved] = 1
|
||||
validPacketProperties[PropResponseInfo][Reserved] = 1
|
||||
validPacketProperties[PropServerReference][Reserved] = 1
|
||||
validPacketProperties[PropReasonString][Reserved] = 1
|
||||
validPacketProperties[PropReceiveMaximum][Reserved] = 1
|
||||
validPacketProperties[PropTopicAliasMaximum][Reserved] = 1
|
||||
validPacketProperties[PropTopicAlias][Reserved] = 1
|
||||
validPacketProperties[PropMaximumQos][Reserved] = 1
|
||||
validPacketProperties[PropRetainAvailable][Reserved] = 1
|
||||
validPacketProperties[PropUser][Reserved] = 1
|
||||
validPacketProperties[PropMaximumPacketSize][Reserved] = 1
|
||||
validPacketProperties[PropWildcardSubAvailable][Reserved] = 1
|
||||
validPacketProperties[PropSubIDAvailable][Reserved] = 1
|
||||
validPacketProperties[PropSharedSubAvailable][Reserved] = 1
|
||||
}
|
||||
|
||||
func TestEncodeProperties(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
require.Equal(t, propertiesBytes, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8}))
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
|
||||
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
|
||||
}
|
||||
|
||||
func TestEncodePropertiesNil(t *testing.T) {
|
||||
type tmp struct {
|
||||
p *Properties
|
||||
}
|
||||
|
||||
pr := tmp{}
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
|
||||
require.Equal(t, []byte{}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodeZeroProperties(t *testing.T) {
|
||||
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
|
||||
props := new(Properties)
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
require.Equal(t, []byte{0x00}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestDecodeProperties(t *testing.T) {
|
||||
b := bytes.NewBuffer(propertiesBytes)
|
||||
|
||||
props := new(Properties)
|
||||
n, err := props.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 172+2, n)
|
||||
require.EqualValues(t, propertiesStruct, *props)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesNil(t *testing.T) {
|
||||
b := bytes.NewBuffer(propertiesBytes)
|
||||
|
||||
type tmp struct {
|
||||
p *Properties
|
||||
}
|
||||
|
||||
pr := tmp{}
|
||||
n, err := pr.p.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadInitialVBI(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrMalformedVariableByteInteger, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesZeroLengthVBI(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{0})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, props, new(Properties))
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadKeyByte(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{64, 1})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesInvalidForPacket(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{1, 99})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesGeneralFailure(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadSubscriptionID(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadUserProps(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCopyProperties(t *testing.T) {
|
||||
require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true))
|
||||
}
|
||||
|
||||
func TestCopyPropertiesNoTransfer(t *testing.T) {
|
||||
pkA := propertiesStruct
|
||||
pkB := pkA.Copy(false)
|
||||
|
||||
// Properties which should never be transferred from one connection to another
|
||||
require.Equal(t, uint16(0), pkB.TopicAlias)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,33 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func encodeTestOK(wanted TPacketCase) bool {
|
||||
if wanted.RawBytes == nil {
|
||||
return false
|
||||
}
|
||||
if wanted.Group != "" && wanted.Group != "encode" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func decodeTestOK(wanted TPacketCase) bool {
|
||||
if wanted.Group != "" && wanted.Group != "decode" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestTPacketCaseGet(t *testing.T) {
|
||||
require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311))
|
||||
require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128)))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,61 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package system
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Info contains atomic counters and values for various server statistics
|
||||
// commonly found in $SYS topics (and others).
|
||||
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
|
||||
type Info struct {
|
||||
Version string `json:"version"` // the current version of the server
|
||||
Started int64 `json:"started"` // the time the server started in unix seconds
|
||||
Time int64 `json:"time"` // current time on the server
|
||||
Uptime int64 `json:"uptime"` // the number of seconds the server has been online
|
||||
BytesReceived int64 `json:"bytes_received"` // total number of bytes received since the broker started
|
||||
BytesSent int64 `json:"bytes_sent"` // total number of bytes sent since the broker started
|
||||
ClientsConnected int64 `json:"clients_connected"` // number of currently connected clients
|
||||
ClientsDisconnected int64 `json:"clients_disconnected"` // total number of persistent clients (with clean session disabled) that are registered at the broker but are currently disconnected
|
||||
ClientsMaximum int64 `json:"clients_maximum"` // maximum number of active clients that have been connected
|
||||
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
|
||||
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
|
||||
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
|
||||
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
|
||||
Retained int64 `json:"retained"` // total number of retained messages active on the broker
|
||||
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
|
||||
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
|
||||
Subscriptions int64 `json:"subscriptions"` // total number of subscriptions active on the broker
|
||||
PacketsReceived int64 `json:"packets_received"` // the total number of publish messages received
|
||||
PacketsSent int64 `json:"packets_sent"` // total number of messages of any type sent since the broker started
|
||||
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
|
||||
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
|
||||
}
|
||||
|
||||
// Clone makes a copy of Info using atomic operation
|
||||
func (i *Info) Clone() *Info {
|
||||
return &Info{
|
||||
Version: i.Version,
|
||||
Started: atomic.LoadInt64(&i.Started),
|
||||
Time: atomic.LoadInt64(&i.Time),
|
||||
Uptime: atomic.LoadInt64(&i.Uptime),
|
||||
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
|
||||
BytesSent: atomic.LoadInt64(&i.BytesSent),
|
||||
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
|
||||
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
|
||||
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
|
||||
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
|
||||
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
|
||||
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
|
||||
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
|
||||
Retained: atomic.LoadInt64(&i.Retained),
|
||||
Inflight: atomic.LoadInt64(&i.Inflight),
|
||||
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
|
||||
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
|
||||
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
|
||||
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
|
||||
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
|
||||
Threads: atomic.LoadInt64(&i.Threads),
|
||||
}
|
||||
}
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
o := &Info{
|
||||
Version: "version",
|
||||
Started: 1,
|
||||
Time: 2,
|
||||
Uptime: 3,
|
||||
BytesReceived: 4,
|
||||
BytesSent: 5,
|
||||
ClientsConnected: 6,
|
||||
ClientsMaximum: 7,
|
||||
ClientsTotal: 8,
|
||||
ClientsDisconnected: 9,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
Retained: 12,
|
||||
Inflight: 13,
|
||||
InflightDropped: 14,
|
||||
Subscriptions: 15,
|
||||
PacketsReceived: 16,
|
||||
PacketsSent: 17,
|
||||
MemoryAlloc: 18,
|
||||
Threads: 19,
|
||||
}
|
||||
|
||||
n := o.Clone()
|
||||
|
||||
require.Equal(t, o, n)
|
||||
}
|
||||
|
|
@ -1,699 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
var (
|
||||
SharePrefix = "$SHARE" // the prefix indicating a share topic
|
||||
SysPrefix = "$SYS" // the prefix indicating a system info topic
|
||||
)
|
||||
|
||||
// TopicAliases contains inbound and outbound topic alias registrations.
|
||||
type TopicAliases struct {
|
||||
Inbound *InboundTopicAliases
|
||||
Outbound *OutboundTopicAliases
|
||||
}
|
||||
|
||||
// NewTopicAliases returns an instance of TopicAliases.
|
||||
func NewTopicAliases(topicAliasMaximum uint16) TopicAliases {
|
||||
return TopicAliases{
|
||||
Inbound: NewInboundTopicAliases(topicAliasMaximum),
|
||||
Outbound: NewOutboundTopicAliases(topicAliasMaximum),
|
||||
}
|
||||
}
|
||||
|
||||
// NewInboundTopicAliases returns a pointer to InboundTopicAliases.
|
||||
func NewInboundTopicAliases(topicAliasMaximum uint16) *InboundTopicAliases {
|
||||
return &InboundTopicAliases{
|
||||
maximum: topicAliasMaximum,
|
||||
internal: map[uint16]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// InboundTopicAliases contains a map of topic aliases received from the client.
|
||||
type InboundTopicAliases struct {
|
||||
internal map[uint16]string
|
||||
sync.RWMutex
|
||||
maximum uint16
|
||||
}
|
||||
|
||||
// Set sets a new alias for a specific topic.
|
||||
func (a *InboundTopicAliases) Set(id uint16, topic string) string {
|
||||
a.Lock()
|
||||
defer a.Unlock()
|
||||
|
||||
if a.maximum == 0 {
|
||||
return topic // ?
|
||||
}
|
||||
|
||||
if existing, ok := a.internal[id]; ok && topic == "" {
|
||||
return existing
|
||||
}
|
||||
|
||||
a.internal[id] = topic
|
||||
return topic
|
||||
}
|
||||
|
||||
// OutboundTopicAliases contains a map of topic aliases sent from the broker to the client.
|
||||
type OutboundTopicAliases struct {
|
||||
internal map[string]uint16
|
||||
sync.RWMutex
|
||||
cursor uint32
|
||||
maximum uint16
|
||||
}
|
||||
|
||||
// NewOutboundTopicAliases returns a pointer to OutboundTopicAliases.
|
||||
func NewOutboundTopicAliases(topicAliasMaximum uint16) *OutboundTopicAliases {
|
||||
return &OutboundTopicAliases{
|
||||
maximum: topicAliasMaximum,
|
||||
internal: map[string]uint16{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a new topic alias for a topic and returns the alias value, and a boolean
|
||||
// indicating if the alias already existed.
|
||||
func (a *OutboundTopicAliases) Set(topic string) (uint16, bool) {
|
||||
a.Lock()
|
||||
defer a.Unlock()
|
||||
|
||||
if a.maximum == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if i, ok := a.internal[topic]; ok {
|
||||
return i, true
|
||||
}
|
||||
|
||||
i := atomic.LoadUint32(&a.cursor)
|
||||
if i+1 > uint32(a.maximum) {
|
||||
// if i+1 > math.MaxUint16 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
a.internal[topic] = uint16(i) + 1
|
||||
atomic.StoreUint32(&a.cursor, i+1)
|
||||
return uint16(i) + 1, false
|
||||
}
|
||||
|
||||
// SharedSubscriptions contains a map of subscriptions to a shared filter,
|
||||
// keyed on share group then client id.
|
||||
type SharedSubscriptions struct {
|
||||
internal map[string]map[string]packets.Subscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSharedSubscriptions returns a new instance of Subscriptions.
|
||||
func NewSharedSubscriptions() *SharedSubscriptions {
|
||||
return &SharedSubscriptions{
|
||||
internal: map[string]map[string]packets.Subscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add creates a new shared subscription for a group and client id pair.
|
||||
func (s *SharedSubscriptions) Add(group, id string, val packets.Subscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
if _, ok := s.internal[group]; !ok {
|
||||
s.internal[group] = map[string]packets.Subscription{}
|
||||
}
|
||||
s.internal[group][id] = val
|
||||
}
|
||||
|
||||
// Delete deletes a client id from a shared subscription group.
|
||||
func (s *SharedSubscriptions) Delete(group, id string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal[group], id)
|
||||
if len(s.internal[group]) == 0 {
|
||||
delete(s.internal, group)
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the subscription properties for a client id in a share group, if one exists.
|
||||
func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
if _, ok := s.internal[group]; !ok {
|
||||
return val, ok
|
||||
}
|
||||
|
||||
val, ok = s.internal[group][id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// GroupLen returns the number of groups subscribed to the filter.
|
||||
func (s *SharedSubscriptions) GroupLen() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Len returns the total number of shared subscriptions to a filter across all groups.
|
||||
func (s *SharedSubscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
n := 0
|
||||
for _, group := range s.internal {
|
||||
n += len(group)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// GetAll returns all shared subscription groups and their subscriptions.
|
||||
func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[string]map[string]packets.Subscription{}
|
||||
for group, subs := range s.internal {
|
||||
if _, ok := m[group]; !ok {
|
||||
m[group] = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for id, sub := range subs {
|
||||
m[group][id] = sub
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Subscriptions is a map of subscriptions keyed on client.
|
||||
type Subscriptions struct {
|
||||
internal map[string]packets.Subscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSubscriptions returns a new instance of Subscriptions.
|
||||
func NewSubscriptions() *Subscriptions {
|
||||
return &Subscriptions{
|
||||
internal: map[string]packets.Subscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new subscription for a client. ID can be a filter in the
|
||||
// case this map is client state, or a client id if particle state.
|
||||
func (s *Subscriptions) Add(id string, val packets.Subscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.internal[id] = val
|
||||
}
|
||||
|
||||
// GetAll returns all subscriptions.
|
||||
func (s *Subscriptions) GetAll() map[string]packets.Subscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[string]packets.Subscription{}
|
||||
for k, v := range s.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns a subscriptions for a specific client or filter id.
|
||||
func (s *Subscriptions) Get(id string) (val packets.Subscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val, ok = s.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the number of subscriptions.
|
||||
func (s *Subscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes a subscription by client or filter id.
|
||||
func (s *Subscriptions) Delete(id string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal, id)
|
||||
}
|
||||
|
||||
// ClientSubscriptions is a map of aggregated subscriptions for a client.
|
||||
type ClientSubscriptions map[string]packets.Subscription
|
||||
|
||||
// Subscribers contains the shared and non-shared subscribers matching a topic.
|
||||
type Subscribers struct {
|
||||
Shared map[string]map[string]packets.Subscription
|
||||
SharedSelected map[string]packets.Subscription
|
||||
Subscriptions map[string]packets.Subscription
|
||||
}
|
||||
|
||||
// SelectShared returns one subscriber for each shared subscription group.
|
||||
func (s *Subscribers) SelectShared() {
|
||||
s.SharedSelected = map[string]packets.Subscription{}
|
||||
for _, subs := range s.Shared {
|
||||
for client, sub := range subs {
|
||||
cls, ok := s.SharedSelected[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
s.SharedSelected[client] = cls.Merge(sub)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MergeSharedSelected merges the selected subscribers for a shared subscription group
|
||||
// and the non-shared subscribers, to ensure that no subscriber gets multiple messages
|
||||
// due to have both types of subscription matching the same filter.
|
||||
func (s *Subscribers) MergeSharedSelected() {
|
||||
for client, sub := range s.SharedSelected {
|
||||
cls, ok := s.Subscriptions[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
s.Subscriptions[client] = cls.Merge(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// TopicsIndex is a prefix/trie tree containing topic subscribers and retained messages.
|
||||
type TopicsIndex struct {
|
||||
Retained *packets.Packets
|
||||
root *particle // a leaf containing a message and more leaves.
|
||||
}
|
||||
|
||||
// NewTopicsIndex returns a pointer to a new instance of Index.
|
||||
func NewTopicsIndex() *TopicsIndex {
|
||||
return &TopicsIndex{
|
||||
Retained: packets.NewPackets(),
|
||||
root: &particle{
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new subscription for a client to a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var existed bool
|
||||
prefix, _ := isolateParticle(subscription.Filter, 0)
|
||||
if strings.EqualFold(prefix, SharePrefix) {
|
||||
group, _ := isolateParticle(subscription.Filter, 1)
|
||||
n := x.set(subscription.Filter, 2)
|
||||
_, existed = n.shared.Get(group, client)
|
||||
n.shared.Add(group, client, subscription)
|
||||
} else {
|
||||
n := x.set(subscription.Filter, 0)
|
||||
_, existed = n.subscriptions.Get(client)
|
||||
n.subscriptions.Add(client, subscription)
|
||||
}
|
||||
|
||||
return !existed
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription filter for a client, returning true if the
|
||||
// subscription existed.
|
||||
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var d int
|
||||
if strings.HasPrefix(filter, SharePrefix) {
|
||||
d = 2
|
||||
}
|
||||
|
||||
particle := x.seek(filter, d)
|
||||
if particle == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
prefix, _ := isolateParticle(filter, 0)
|
||||
if strings.EqualFold(prefix, SharePrefix) {
|
||||
group, _ := isolateParticle(filter, 1)
|
||||
particle.shared.Delete(group, client)
|
||||
} else {
|
||||
particle.subscriptions.Delete(client)
|
||||
}
|
||||
|
||||
x.trim(particle)
|
||||
return true
|
||||
}
|
||||
|
||||
// RetainMessage saves a message payload to the end of a topic address. Returns
|
||||
// 1 if a retained message was added, and -1 if the retained message was removed.
|
||||
// 0 is returned if sequential empty payloads are received.
|
||||
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
n := x.set(pk.TopicName, 0)
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
if len(pk.Payload) > 0 {
|
||||
n.retainPath = pk.TopicName
|
||||
x.Retained.Add(pk.TopicName, pk)
|
||||
return 1
|
||||
}
|
||||
|
||||
var out int64
|
||||
if pke, ok := x.Retained.Get(pk.TopicName); ok && len(pke.Payload) > 0 && pke.FixedHeader.Retain {
|
||||
out = -1 // if a retained packet existed, return -1
|
||||
}
|
||||
|
||||
n.retainPath = ""
|
||||
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
|
||||
x.trim(n)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// set creates a topic address in the index and returns the final particle.
|
||||
func (x *TopicsIndex) set(topic string, d int) *particle {
|
||||
var key string
|
||||
var hasNext = true
|
||||
n := x.root
|
||||
for hasNext {
|
||||
key, hasNext = isolateParticle(topic, d)
|
||||
d++
|
||||
|
||||
p := n.particles.get(key)
|
||||
if p == nil {
|
||||
p = newParticle(key, n)
|
||||
n.particles.add(p)
|
||||
}
|
||||
n = p
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// seek finds the particle at a specific index in a topic filter.
|
||||
func (x *TopicsIndex) seek(filter string, d int) *particle {
|
||||
var key string
|
||||
var hasNext = true
|
||||
n := x.root
|
||||
for hasNext {
|
||||
key, hasNext = isolateParticle(filter, d)
|
||||
n = n.particles.get(key)
|
||||
d++
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// trim removes empty filter particles from the index.
|
||||
func (x *TopicsIndex) trim(n *particle) {
|
||||
for n.parent != nil && n.retainPath == "" && n.particles.len()+n.subscriptions.Len()+n.shared.Len() == 0 {
|
||||
key := n.key
|
||||
n = n.parent
|
||||
n.particles.delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Messages returns a slice of any retained messages which match a filter.
|
||||
func (x *TopicsIndex) Messages(filter string) []packets.Packet {
|
||||
return x.scanMessages(filter, 0, nil, []packets.Packet{})
|
||||
}
|
||||
|
||||
// scanMessages returns all retained messages on topics matching a given filter.
|
||||
func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []packets.Packet) []packets.Packet {
|
||||
if n == nil {
|
||||
n = x.root
|
||||
}
|
||||
|
||||
if len(filter) == 0 || x.Retained.Len() == 0 {
|
||||
return pks
|
||||
}
|
||||
|
||||
if !strings.ContainsRune(filter, '#') && !strings.ContainsRune(filter, '+') {
|
||||
if pk, ok := x.Retained.Get(filter); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
return pks
|
||||
}
|
||||
|
||||
key, hasNext := isolateParticle(filter, d)
|
||||
if key == "+" || key == "#" || d == -1 {
|
||||
for _, adjacent := range n.particles.getAll() {
|
||||
if d == 0 && adjacent.key == SysPrefix {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hasNext {
|
||||
if adjacent.retainPath != "" {
|
||||
if pk, ok := x.Retained.Get(adjacent.retainPath); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasNext || (d >= 0 && key == "#") {
|
||||
pks = x.scanMessages(filter, d+1, adjacent, pks)
|
||||
}
|
||||
}
|
||||
return pks
|
||||
}
|
||||
|
||||
if particle := n.particles.get(key); particle != nil {
|
||||
if hasNext {
|
||||
return x.scanMessages(filter, d+1, particle, pks)
|
||||
}
|
||||
|
||||
if pk, ok := x.Retained.Get(particle.retainPath); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
}
|
||||
|
||||
return pks
|
||||
}
|
||||
|
||||
// Subscribers returns a map of clients who are subscribed to matching filters,
|
||||
// their subscription ids and highest qos.
|
||||
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
|
||||
return x.scanSubscribers(topic, 0, nil, &Subscribers{
|
||||
Shared: map[string]map[string]packets.Subscription{},
|
||||
SharedSelected: map[string]packets.Subscription{},
|
||||
Subscriptions: map[string]packets.Subscription{},
|
||||
})
|
||||
}
|
||||
|
||||
// scanSubscribers returns a list of client subscriptions matching an indexed topic address.
|
||||
func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Subscribers) *Subscribers {
|
||||
if n == nil {
|
||||
n = x.root
|
||||
}
|
||||
|
||||
if len(topic) == 0 {
|
||||
return subs
|
||||
}
|
||||
|
||||
key, hasNext := isolateParticle(topic, d)
|
||||
for _, partKey := range []string{key, "+", "#"} {
|
||||
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
if wild := particle.particles.get("#"); wild != nil && partKey != "#" && partKey != "+" {
|
||||
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
|
||||
}
|
||||
|
||||
if hasNext {
|
||||
x.scanSubscribers(topic, d+1, particle, subs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return subs
|
||||
}
|
||||
|
||||
// gatherSubscriptions collects any matching subscriptions, and gathers any identifiers or highest qos values.
|
||||
func (x *TopicsIndex) gatherSubscriptions(topic string, particle *particle, subs *Subscribers) {
|
||||
if subs.Subscriptions == nil {
|
||||
subs.Subscriptions = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for client, sub := range particle.subscriptions.GetAll() {
|
||||
if len(sub.Filter) > 0 && topic[0] == '$' && (sub.Filter[0] == '+' || sub.Filter[0] == '#') { // don't match $ topics with top level wildcards [MQTT-4.7.1-1] [MQTT-4.7.1-2]
|
||||
continue
|
||||
}
|
||||
|
||||
cls, ok := subs.Subscriptions[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
subs.Subscriptions[client] = cls.Merge(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// gatherSharedSubscriptions gathers all shared subscriptions for a particle.
|
||||
func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscribers) {
|
||||
if subs.Shared == nil {
|
||||
subs.Shared = map[string]map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for _, shares := range particle.shared.GetAll() {
|
||||
for client, sub := range shares {
|
||||
if _, ok := subs.Shared[sub.Filter]; !ok {
|
||||
subs.Shared[sub.Filter] = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
subs.Shared[sub.Filter][client] = sub
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isolateParticle extracts a particle between d / and d+1 / without allocations.
|
||||
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
|
||||
var next, end int
|
||||
for i := 0; end > -1 && i <= d; i++ {
|
||||
end = strings.IndexRune(filter, '/')
|
||||
|
||||
switch {
|
||||
case d > -1 && i == d && end > -1:
|
||||
hasNext = true
|
||||
particle = filter[next:end]
|
||||
case end > -1:
|
||||
hasNext = false
|
||||
filter = filter[end+1:]
|
||||
default:
|
||||
hasNext = false
|
||||
particle = filter[next:]
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsSharedFilter returns true if the filter uses the share prefix.
|
||||
func IsSharedFilter(filter string) bool {
|
||||
prefix, _ := isolateParticle(filter, 0)
|
||||
return strings.EqualFold(prefix, SharePrefix)
|
||||
}
|
||||
|
||||
// IsValidFilter returns true if the filter is valid.
|
||||
func IsValidFilter(filter string, forPublish bool) bool {
|
||||
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs.
|
||||
return false // [MQTT-4.7.3-1]
|
||||
}
|
||||
|
||||
if forPublish {
|
||||
if len(filter) >= len(SysPrefix) && strings.EqualFold(filter[0:len(SysPrefix)], SysPrefix) {
|
||||
// 4.7.2 Non-normative - The Server SHOULD prevent Clients from using such Topic Names [$SYS] to exchange messages with other Clients.
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.ContainsRune(filter, '+') || strings.ContainsRune(filter, '#') {
|
||||
return false //[MQTT-3.3.2-2]
|
||||
}
|
||||
}
|
||||
|
||||
wildhash := strings.IndexRune(filter, '#')
|
||||
if wildhash >= 0 && wildhash != len(filter)-1 { // [MQTT-4.7.1-2]
|
||||
return false
|
||||
}
|
||||
|
||||
prefix, hasNext := isolateParticle(filter, 0)
|
||||
if !hasNext && strings.EqualFold(prefix, SharePrefix) {
|
||||
return false // [MQTT-4.8.2-1]
|
||||
}
|
||||
|
||||
if hasNext && strings.EqualFold(prefix, SharePrefix) {
|
||||
group, hasNext := isolateParticle(filter, 1)
|
||||
if !hasNext {
|
||||
return false // [MQTT-4.8.2-1]
|
||||
}
|
||||
|
||||
if strings.ContainsRune(group, '+') || strings.ContainsRune(group, '#') {
|
||||
return false // [MQTT-4.8.2-2]
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// particle is a child node on the tree.
|
||||
type particle struct {
|
||||
key string // the key of the particle
|
||||
parent *particle // a pointer to the parent of the particle
|
||||
particles particles // a map of child particles
|
||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||
retainPath string // path of a retained message
|
||||
sync.Mutex // mutex for when making changes to the particle
|
||||
}
|
||||
|
||||
// newParticle returns a pointer to a new instance of particle.
|
||||
func newParticle(key string, parent *particle) *particle {
|
||||
return &particle{
|
||||
key: key,
|
||||
parent: parent,
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
shared: NewSharedSubscriptions(),
|
||||
}
|
||||
}
|
||||
|
||||
// particles is a concurrency safe map of particles.
|
||||
type particles struct {
|
||||
internal map[string]*particle
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// newParticles returns a map of particles.
|
||||
func newParticles() particles {
|
||||
return particles{
|
||||
internal: map[string]*particle{},
|
||||
}
|
||||
}
|
||||
|
||||
// add adds a new particle.
|
||||
func (p *particles) add(val *particle) {
|
||||
p.Lock()
|
||||
p.internal[val.key] = val
|
||||
p.Unlock()
|
||||
}
|
||||
|
||||
// getAll returns all particles.
|
||||
func (p *particles) getAll() map[string]*particle {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
m := map[string]*particle{}
|
||||
for k, v := range p.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// get returns a particle by id (key).
|
||||
func (p *particles) get(id string) *particle {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
return p.internal[id]
|
||||
}
|
||||
|
||||
// len returns the number of particles.
|
||||
func (p *particles) len() int {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
val := len(p.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// delete removes a particle.
|
||||
func (p *particles) delete(id string) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
delete(p.internal, id)
|
||||
}
|
||||
|
|
@ -1,842 +0,0 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testGroup = "testgroup"
|
||||
otherGroup = "other"
|
||||
)
|
||||
|
||||
func TestNewSharedSubscriptions(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
require.NotNil(t, s.internal)
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsAdd(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Filter: "a/b/c"})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsGet(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Qos: 2})
|
||||
s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl2")
|
||||
|
||||
sub, ok := s.Get(testGroup, "cl2")
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, byte(2), sub.Qos)
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsGetAll(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
|
||||
s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
|
||||
s.Add(otherGroup, "cl3", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl2")
|
||||
require.Contains(t, s.internal, otherGroup)
|
||||
require.Contains(t, s.internal[otherGroup], "cl3")
|
||||
|
||||
subs := s.GetAll()
|
||||
require.Len(t, subs, 2)
|
||||
require.Len(t, subs[testGroup], 2)
|
||||
require.Len(t, subs[otherGroup], 1)
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsLen(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
|
||||
s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
|
||||
s.Add(otherGroup, "cl2", packets.Subscription{Qos: 1})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl2")
|
||||
require.Contains(t, s.internal, otherGroup)
|
||||
require.Contains(t, s.internal[otherGroup], "cl2")
|
||||
require.Equal(t, 3, s.Len())
|
||||
require.Equal(t, 2, s.GroupLen())
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsDelete(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Qos: 1})
|
||||
s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl2")
|
||||
|
||||
require.Equal(t, 2, s.Len())
|
||||
|
||||
s.Delete(testGroup, "cl1")
|
||||
_, ok := s.Get(testGroup, "cl1")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 1, s.GroupLen())
|
||||
require.Equal(t, 1, s.Len())
|
||||
|
||||
s.Delete(testGroup, "cl2")
|
||||
_, ok = s.Get(testGroup, "cl2")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0, s.GroupLen())
|
||||
require.Equal(t, 0, s.Len())
|
||||
}
|
||||
|
||||
func TestNewSubscriptions(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
require.NotNil(t, s.internal)
|
||||
}
|
||||
|
||||
func TestSubscriptionsAdd(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
}
|
||||
|
||||
func TestSubscriptionsGet(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{Qos: 2})
|
||||
s.Add("cl2", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
|
||||
sub, ok := s.Get("cl1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, byte(2), sub.Qos)
|
||||
}
|
||||
|
||||
func TestSubscriptionsGetAll(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{Qos: 0})
|
||||
s.Add("cl2", packets.Subscription{Qos: 1})
|
||||
s.Add("cl3", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Contains(t, s.internal, "cl3")
|
||||
|
||||
subs := s.GetAll()
|
||||
require.Len(t, subs, 3)
|
||||
}
|
||||
|
||||
func TestSubscriptionsLen(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{Qos: 0})
|
||||
s.Add("cl2", packets.Subscription{Qos: 1})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Equal(t, 2, s.Len())
|
||||
}
|
||||
|
||||
func TestSubscriptionsDelete(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{Qos: 1})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
|
||||
s.Delete("cl1")
|
||||
_, ok := s.Get("cl1")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestNewTopicsIndex(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
require.NotNil(t, index)
|
||||
require.NotNil(t, index.root)
|
||||
}
|
||||
|
||||
func BenchmarkNewTopicsIndex(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
NewTopicsIndex()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
tt := []struct {
|
||||
desc string
|
||||
client string
|
||||
filter string
|
||||
subscription packets.Subscription
|
||||
wasNew bool
|
||||
}{
|
||||
{
|
||||
desc: "subscribe",
|
||||
client: "cl1",
|
||||
|
||||
subscription: packets.Subscription{Filter: "a/b/c", Qos: 2},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "subscribe existed",
|
||||
client: "cl1",
|
||||
|
||||
subscription: packets.Subscription{Filter: "a/b/c", Qos: 1},
|
||||
wasNew: false,
|
||||
},
|
||||
{
|
||||
desc: "subscribe case sensitive didnt exist",
|
||||
client: "cl1",
|
||||
|
||||
subscription: packets.Subscription{Filter: "A/B/c", Qos: 1},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard+ sub",
|
||||
client: "cl1",
|
||||
|
||||
subscription: packets.Subscription{Filter: "d/+"},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard# sub",
|
||||
client: "cl1",
|
||||
subscription: packets.Subscription{Filter: "d/e/#"},
|
||||
wasNew: true,
|
||||
},
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
for _, tx := range tt {
|
||||
t.Run(tx.desc, func(t *testing.T) {
|
||||
require.Equal(t, tx.wasNew, index.Subscribe(tx.client, tx.subscription))
|
||||
})
|
||||
}
|
||||
|
||||
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||
require.NotNil(t, final)
|
||||
client, exists := final.subscriptions.Get("cl1")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, byte(1), client.Qos)
|
||||
}
|
||||
|
||||
func TestSubscribeShared(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c", Qos: 2})
|
||||
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||
require.NotNil(t, final)
|
||||
client, exists := final.shared.Get("tmp", "cl1")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, byte(2), client.Qos)
|
||||
require.Equal(t, 0, final.subscriptions.Len())
|
||||
require.Equal(t, 1, final.shared.Len())
|
||||
}
|
||||
|
||||
func BenchmarkSubscribe(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribe("client-1", packets.Subscription{Filter: "a/b/c"})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSubscribeShared(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribe("client-1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c"})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/d", Qos: 1})
|
||||
client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/+/d", Qos: 1})
|
||||
client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "d/e/f", Qos: 1})
|
||||
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
index.Subscribe("cl2", packets.Subscription{Filter: "d/e/f", Qos: 1})
|
||||
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
index.Subscribe("cl3", packets.Subscription{Filter: "#", Qos: 2})
|
||||
client, exists = index.root.particles.get("#").subscriptions.Get("cl3")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
ok := index.Unsubscribe("a/b/c/d", "cl1")
|
||||
require.True(t, ok)
|
||||
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
ok = index.Unsubscribe("d/e/f", "cl1")
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, 1, index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Len())
|
||||
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "nobody")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestUnsubscribeNoCascade(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/e/e"})
|
||||
|
||||
ok := index.Unsubscribe("a/b/c/e/e", "cl1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 1, index.root.particles.len())
|
||||
|
||||
client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
}
|
||||
|
||||
func TestUnsubscribeShared(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c", Qos: 2})
|
||||
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||
require.NotNil(t, final)
|
||||
client, exists := final.shared.Get("tmp", "cl1")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, byte(2), client.Qos)
|
||||
|
||||
require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1"))
|
||||
_, exists = final.shared.Get("tmp", "cl1")
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func BenchmarkUnsubscribe(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
b.StopTimer()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
|
||||
b.StartTimer()
|
||||
index.Unsubscribe("a/b/c", "cl1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexSeek(t *testing.T) {
|
||||
filter := "a/b/c/d/e/f"
|
||||
index := NewTopicsIndex()
|
||||
k1 := index.set(filter, 0)
|
||||
require.Equal(t, "f", k1.key)
|
||||
k1.subscriptions.Add("cl1", packets.Subscription{})
|
||||
|
||||
require.Equal(t, k1, index.seek(filter, 0))
|
||||
require.Nil(t, index.seek("d/e/f", 0))
|
||||
}
|
||||
|
||||
func TestIndexTrim(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
k1 := index.set("a/b/c", 0)
|
||||
require.Equal(t, "c", k1.key)
|
||||
k1.subscriptions.Add("cl1", packets.Subscription{})
|
||||
|
||||
k2 := index.set("a/b/c/d/e/f", 0)
|
||||
require.Equal(t, "f", k2.key)
|
||||
k2.subscriptions.Add("cl1", packets.Subscription{})
|
||||
|
||||
k3 := index.set("a/b", 0)
|
||||
require.Equal(t, "b", k3.key)
|
||||
k3.subscriptions.Add("cl1", packets.Subscription{})
|
||||
|
||||
index.trim(k2)
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").particles.get("e").particles.get("f"))
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b"))
|
||||
|
||||
k2.subscriptions.Delete("cl1")
|
||||
index.trim(k2)
|
||||
|
||||
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d"))
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
|
||||
k1.subscriptions.Delete("cl1")
|
||||
k3.subscriptions.Delete("cl1")
|
||||
index.trim(k2)
|
||||
require.Nil(t, index.root.particles.get("a"))
|
||||
}
|
||||
|
||||
func TestIndexSet(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
child := index.set("a/b/c", 0)
|
||||
require.Equal(t, "c", child.key)
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
|
||||
child = index.set("a/b/c/d/e", 0)
|
||||
require.Equal(t, "e", child.key)
|
||||
|
||||
child = index.set("a/b/c/c/a", 0)
|
||||
require.Equal(t, "a", child.key)
|
||||
}
|
||||
|
||||
func TestIndexSetPrefixed(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
child := index.set("/c", 0)
|
||||
require.Equal(t, "c", child.key)
|
||||
require.NotNil(t, index.root.particles.get("").particles.get("c"))
|
||||
}
|
||||
|
||||
func BenchmarkIndexSet(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.set("a/b/c", 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetainMessage(t *testing.T) {
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{Retain: true},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello"),
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
r := index.RetainMessage(pk)
|
||||
require.Equal(t, int64(1), r)
|
||||
pke, ok := index.Retained.Get(pk.TopicName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, pk, pke)
|
||||
|
||||
pk2 := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{Retain: true},
|
||||
TopicName: "a/b/d/f",
|
||||
Payload: []byte("hello"),
|
||||
}
|
||||
r = index.RetainMessage(pk2)
|
||||
require.Equal(t, int64(1), r)
|
||||
// The same message already exists, but we're not doing a deep-copy check, so it's considered to be a new message.
|
||||
r = index.RetainMessage(pk2)
|
||||
require.Equal(t, int64(1), r)
|
||||
|
||||
// Clear existing retained
|
||||
pk3 := packets.Packet{TopicName: "a/b/c", Payload: []byte{}}
|
||||
r = index.RetainMessage(pk3)
|
||||
require.Equal(t, int64(-1), r)
|
||||
_, ok = index.Retained.Get(pk.TopicName)
|
||||
require.False(t, ok)
|
||||
|
||||
// Clear no retained
|
||||
r = index.RetainMessage(pk3)
|
||||
require.Equal(t, int64(0), r)
|
||||
}
|
||||
|
||||
func BenchmarkRetainMessage(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsolateParticle(t *testing.T) {
|
||||
particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
|
||||
require.Equal(t, "path", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
|
||||
require.Equal(t, "to", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
|
||||
require.Equal(t, "my", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
|
||||
require.Equal(t, "mqtt", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
|
||||
particle, hasNext = isolateParticle("/path/", 0)
|
||||
require.Equal(t, "", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("/path/", 1)
|
||||
require.Equal(t, "path", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("/path/", 2)
|
||||
require.Equal(t, "", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
|
||||
particle, hasNext = isolateParticle("a/b/c/+/+", 3)
|
||||
require.Equal(t, "+", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("a/b/c/+/+", 4)
|
||||
require.Equal(t, "+", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
}
|
||||
|
||||
func BenchmarkIsolateParticle(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
isolateParticle("path/to/my/mqtt", 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSubscribers(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c", Identifier: 22})
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c/d/e/f"})
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 2, Filter: "a/b/c/d/+/f"})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/#"})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 1, Filter: "a/b/c"})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "a/b/+", Identifier: 77})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "d/e/f", Identifier: 7237})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "$SYS/uptime", Identifier: 3})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: "+/b", Identifier: 234})
|
||||
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: "#", Identifier: 5})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})
|
||||
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 4, len(subs.Subscriptions))
|
||||
require.Contains(t, subs.Subscriptions, "cl1")
|
||||
require.Contains(t, subs.Subscriptions, "cl2")
|
||||
require.Contains(t, subs.Subscriptions, "cl3")
|
||||
require.Contains(t, subs.Subscriptions, "cl4")
|
||||
|
||||
require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
|
||||
require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
|
||||
require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos)
|
||||
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
|
||||
|
||||
require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
|
||||
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
|
||||
require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
|
||||
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
|
||||
require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"])
|
||||
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
|
||||
|
||||
subs = index.scanSubscribers("", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 0, len(subs.Subscriptions))
|
||||
}
|
||||
|
||||
func TestScanSubscribersShared(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
|
||||
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 3, len(subs.Shared))
|
||||
}
|
||||
|
||||
func TestSelectSharedSubscriber(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110})
|
||||
index.Subscribe("cl1b", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 2, len(subs.Shared))
|
||||
require.Contains(t, subs.Shared, SharePrefix+"/tmp/a/b/c")
|
||||
require.Contains(t, subs.Shared, SharePrefix+"/tmp2/a/b/c")
|
||||
require.Len(t, subs.Shared[SharePrefix+"/tmp/a/b/c"], 3)
|
||||
require.Len(t, subs.Shared[SharePrefix+"/tmp2/a/b/c"], 1)
|
||||
subs.SelectShared()
|
||||
require.Len(t, subs.SharedSelected, 2)
|
||||
}
|
||||
|
||||
func TestMergeSharedSelected(t *testing.T) {
|
||||
s := &Subscribers{
|
||||
SharedSelected: map[string]packets.Subscription{
|
||||
"cl1": {Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110},
|
||||
"cl2": {Qos: 1, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 111},
|
||||
},
|
||||
Subscriptions: map[string]packets.Subscription{
|
||||
"cl2": {Qos: 1, Filter: "a/b/c", Identifier: 112},
|
||||
},
|
||||
}
|
||||
|
||||
s.MergeSharedSelected()
|
||||
|
||||
require.Equal(t, 2, len(s.Subscriptions))
|
||||
require.Contains(t, s.Subscriptions, "cl1")
|
||||
require.Contains(t, s.Subscriptions, "cl2")
|
||||
require.EqualValues(t, map[string]int{
|
||||
SharePrefix + "/tmp2/a/b/c": 111,
|
||||
"a/b/c": 112,
|
||||
}, s.Subscriptions["cl2"].Identifiers)
|
||||
}
|
||||
|
||||
func TestSubscribersFind(t *testing.T) {
|
||||
tt := []struct {
|
||||
filter string
|
||||
topic string
|
||||
matched bool
|
||||
}{
|
||||
{filter: "a", topic: "a", matched: true},
|
||||
{filter: "a/", topic: "a", matched: false},
|
||||
{filter: "a/", topic: "a/", matched: true},
|
||||
{filter: "/a", topic: "/a", matched: true},
|
||||
{filter: "path/to/my/mqtt", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "path/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "+/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "#", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "+/+/+/+", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "+/+/+/#", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "zen/#", topic: "zen", matched: true}, // as per 4.7.1.2
|
||||
{filter: "trailing-end/#", topic: "trailing-end/", matched: true},
|
||||
{filter: "+/prefixed", topic: "/prefixed", matched: true},
|
||||
{filter: "+/+/#", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "path/to/", topic: "path/to/my/mqtt", matched: false},
|
||||
{filter: "#/stuff", topic: "path/to/my/mqtt", matched: false},
|
||||
{filter: "#", topic: "$SYS/info", matched: false},
|
||||
{filter: "$SYS/#", topic: "$SYS/info", matched: true},
|
||||
{filter: "+/info", topic: "$SYS/info", matched: false},
|
||||
}
|
||||
|
||||
for _, tx := range tt {
|
||||
t.Run("filter:'"+tx.filter+"' vs topic:'"+tx.topic+"'", func(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: tx.filter})
|
||||
subs := index.Subscribers(tx.topic)
|
||||
require.Equal(t, tx.matched, len(subs.Subscriptions) == 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSubscribers(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/+/c"})
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/+"})
|
||||
index.Subscribe("cl2", packets.Subscription{Filter: "a/b/c/d"})
|
||||
index.Subscribe("cl3", packets.Subscription{Filter: "#"})
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribers("a/b/c")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPattern(t *testing.T) {
|
||||
payload := []byte("hello")
|
||||
fh := packets.FixedHeader{Type: packets.Publish, Retain: true}
|
||||
|
||||
pks := []packets.Packet{
|
||||
{TopicName: "$SYS/uptime", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "$SYS/info", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "a/b/c/d", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "a/b/c/e", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "a/b/d/f", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "q/w/e/r/t/y", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "q/x/e/r/t/o", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "asdf", Payload: payload, FixedHeader: fh},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
filter string
|
||||
len int
|
||||
}{
|
||||
{"a/b/c/d", 1},
|
||||
{"$SYS/+", 2},
|
||||
{"$SYS/#", 2},
|
||||
{"#", len(pks) - 2},
|
||||
{"a/b/c/+", 2},
|
||||
{"a/+/c/+", 2},
|
||||
{"+/+/+/d", 1},
|
||||
{"q/w/e/#", 1},
|
||||
{"+/+/+/+", 3},
|
||||
{"q/#", 2},
|
||||
{"asdf", 1},
|
||||
{"", 0},
|
||||
{"#", 6},
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
for _, pk := range pks {
|
||||
index.RetainMessage(pk)
|
||||
}
|
||||
|
||||
for _, tx := range tt {
|
||||
t.Run("filter:'"+tx.filter, func(t *testing.T) {
|
||||
messages := index.Messages(tx.filter)
|
||||
require.Equal(t, tx.len, len(messages))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessages(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "a/b/d/e/f"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "d/e/f/g"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "$SYS/info"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Messages("+/b/c/+")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewParticles(t *testing.T) {
|
||||
cl := newParticles()
|
||||
require.NotNil(t, cl.internal)
|
||||
}
|
||||
|
||||
func TestParticlesAdd(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
}
|
||||
|
||||
func TestParticlesGet(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
p.add(&particle{key: "b"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
require.Contains(t, p.internal, "b")
|
||||
|
||||
particle := p.get("a")
|
||||
require.NotNil(t, particle)
|
||||
require.Equal(t, "a", particle.key)
|
||||
}
|
||||
|
||||
func TestParticlesGetAll(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
p.add(&particle{key: "b"})
|
||||
p.add(&particle{key: "c"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
require.Contains(t, p.internal, "b")
|
||||
require.Contains(t, p.internal, "c")
|
||||
|
||||
particles := p.getAll()
|
||||
require.Len(t, particles, 3)
|
||||
}
|
||||
|
||||
func TestParticlesLen(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
p.add(&particle{key: "b"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
require.Contains(t, p.internal, "b")
|
||||
require.Equal(t, 2, p.len())
|
||||
}
|
||||
|
||||
func TestParticlesDelete(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
|
||||
p.delete("a")
|
||||
particle := p.get("a")
|
||||
require.Nil(t, particle)
|
||||
}
|
||||
|
||||
func TestIsValid(t *testing.T) {
|
||||
require.True(t, IsValidFilter("a/b/c", false))
|
||||
require.True(t, IsValidFilter("a/b//c", false))
|
||||
require.True(t, IsValidFilter("$SYS", false))
|
||||
require.True(t, IsValidFilter("$SYS/info", false))
|
||||
require.True(t, IsValidFilter("$sys/info", false))
|
||||
require.True(t, IsValidFilter("abc/#", false))
|
||||
require.False(t, IsValidFilter("", false))
|
||||
require.False(t, IsValidFilter(SharePrefix, false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/", false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/b+/", false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/+", false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/#", false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/#/", false))
|
||||
require.False(t, IsValidFilter("a/#/c", false))
|
||||
}
|
||||
|
||||
func TestIsValidForPublish(t *testing.T) {
|
||||
require.True(t, IsValidFilter("", true))
|
||||
require.True(t, IsValidFilter("a/b/c", true))
|
||||
require.False(t, IsValidFilter("a/b/+/d", true))
|
||||
require.False(t, IsValidFilter("a/b/#", true))
|
||||
require.False(t, IsValidFilter("$SYS/info", true))
|
||||
}
|
||||
|
||||
func TestIsSharedFilter(t *testing.T) {
|
||||
require.True(t, IsSharedFilter(SharePrefix+"/tmp/a/b/c"))
|
||||
require.False(t, IsSharedFilter("a/b/c"))
|
||||
}
|
||||
|
||||
func TestNewInboundAliases(t *testing.T) {
|
||||
a := NewInboundTopicAliases(5)
|
||||
require.NotNil(t, a)
|
||||
require.NotNil(t, a.internal)
|
||||
require.Equal(t, uint16(5), a.maximum)
|
||||
}
|
||||
|
||||
func TestInboundAliasesSet(t *testing.T) {
|
||||
topic := "test"
|
||||
id := uint16(1)
|
||||
a := NewInboundTopicAliases(5)
|
||||
require.Equal(t, topic, a.Set(id, topic))
|
||||
require.Contains(t, a.internal, id)
|
||||
require.Equal(t, a.internal[id], topic)
|
||||
|
||||
require.Equal(t, topic, a.Set(id, ""))
|
||||
}
|
||||
|
||||
func TestInboundAliasesSetMaxZero(t *testing.T) {
|
||||
topic := "test"
|
||||
id := uint16(1)
|
||||
a := NewInboundTopicAliases(0)
|
||||
require.Equal(t, topic, a.Set(id, topic))
|
||||
require.NotContains(t, a.internal, id)
|
||||
}
|
||||
|
||||
func TestNewOutboundAliases(t *testing.T) {
|
||||
a := NewOutboundTopicAliases(5)
|
||||
require.NotNil(t, a)
|
||||
require.NotNil(t, a.internal)
|
||||
require.Equal(t, uint16(5), a.maximum)
|
||||
require.Equal(t, uint32(0), a.cursor)
|
||||
}
|
||||
|
||||
func TestOutboundAliasesSet(t *testing.T) {
|
||||
a := NewOutboundTopicAliases(3)
|
||||
n, ok := a.Set("t1")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(1), n)
|
||||
|
||||
n, ok = a.Set("t2")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(2), n)
|
||||
|
||||
n, ok = a.Set("t3")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(3), n)
|
||||
|
||||
n, ok = a.Set("t4")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(0), n)
|
||||
|
||||
n, ok = a.Set("t2")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint16(2), n)
|
||||
}
|
||||
|
||||
func TestOutboundAliasesSetMaxZero(t *testing.T) {
|
||||
topic := "test"
|
||||
a := NewOutboundTopicAliases(0)
|
||||
n, ok := a.Set(topic)
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(0), n)
|
||||
}
|
||||
|
||||
func TestNewTopicAliases(t *testing.T) {
|
||||
a := NewTopicAliases(5)
|
||||
require.NotNil(t, a.Inbound)
|
||||
require.Equal(t, uint16(5), a.Inbound.maximum)
|
||||
require.NotNil(t, a.Outbound)
|
||||
require.Equal(t, uint16(5), a.Outbound.maximum)
|
||||
}
|
||||
1
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/.travis.yml
generated
vendored
1
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/.travis.yml
generated
vendored
|
|
@ -1 +0,0 @@
|
|||
language: go
|
||||
35
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/LICENSE
generated
vendored
35
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/LICENSE
generated
vendored
|
|
@ -1,35 +0,0 @@
|
|||
bbloom.go
|
||||
|
||||
// The MIT License (MIT)
|
||||
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
|
||||
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
siphash.go
|
||||
|
||||
// https://github.com/dchest/siphash
|
||||
//
|
||||
// Written in 2012 by Dmitry Chestnykh.
|
||||
//
|
||||
// To the extent possible under law, the author have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
// http://creativecommons.org/publicdomain/zero/1.0/
|
||||
//
|
||||
// Package siphash implements SipHash-2-4, a fast short-input PRF
|
||||
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
|
||||
131
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/README.md
generated
vendored
131
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/README.md
generated
vendored
|
|
@ -1,131 +0,0 @@
|
|||
## bbloom: a bitset Bloom filter for go/golang
|
||||
===
|
||||
|
||||
[](http://travis-ci.org/AndreasBriese/bbloom)
|
||||
|
||||
package implements a fast bloom filter with real 'bitset' and JSONMarshal/JSONUnmarshal to store/reload the Bloom filter.
|
||||
|
||||
NOTE: the package uses unsafe.Pointer to set and read the bits from the bitset. If you're uncomfortable with using the unsafe package, please consider using my bloom filter package at github.com/AndreasBriese/bloom
|
||||
|
||||
===
|
||||
|
||||
changelog 11/2015: new thread safe methods AddTS(), HasTS(), AddIfNotHasTS() following a suggestion from Srdjan Marinovic (github @a-little-srdjan), who used this to code a bloomfilter cache.
|
||||
|
||||
This bloom filter was developed to strengthen a website-log database and was tested and optimized for this log-entry mask: "2014/%02i/%02i %02i:%02i:%02i /info.html".
|
||||
Nonetheless bbloom should work with any other form of entries.
|
||||
|
||||
~~Hash function is a modified Berkeley DB sdbm hash (to optimize for smaller strings). sdbm http://www.cse.yorku.ca/~oz/hash.html~~
|
||||
|
||||
Found sipHash (SipHash-2-4, a fast short-input PRF created by Jean-Philippe Aumasson and Daniel J. Bernstein.) to be about as fast. sipHash had been ported by Dimtry Chestnyk to Go (github.com/dchest/siphash )
|
||||
|
||||
Minimum hashset size is: 512 ([4]uint64; will be set automatically).
|
||||
|
||||
###install
|
||||
|
||||
```sh
|
||||
go get github.com/AndreasBriese/bbloom
|
||||
```
|
||||
|
||||
###test
|
||||
+ change to folder ../bbloom
|
||||
+ create wordlist in file "words.txt" (you might use `python permut.py`)
|
||||
+ run 'go test -bench=.' within the folder
|
||||
|
||||
```go
|
||||
go test -bench=.
|
||||
```
|
||||
|
||||
~~If you've installed the GOCONVEY TDD-framework http://goconvey.co/ you can run the tests automatically.~~
|
||||
|
||||
using go's testing framework now (have in mind that the op timing is related to 65536 operations of Add, Has, AddIfNotHas respectively)
|
||||
|
||||
### usage
|
||||
|
||||
after installation add
|
||||
|
||||
```go
|
||||
import (
|
||||
...
|
||||
"github.com/AndreasBriese/bbloom"
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
at your header. In the program use
|
||||
|
||||
```go
|
||||
// create a bloom filter for 65536 items and 1 % wrong-positive ratio
|
||||
bf := bbloom.New(float64(1<<16), float64(0.01))
|
||||
|
||||
// or
|
||||
// create a bloom filter with 650000 for 65536 items and 7 locs per hash explicitly
|
||||
// bf = bbloom.New(float64(650000), float64(7))
|
||||
// or
|
||||
bf = bbloom.New(650000.0, 7.0)
|
||||
|
||||
// add one item
|
||||
bf.Add([]byte("butter"))
|
||||
|
||||
// Number of elements added is exposed now
|
||||
// Note: ElemNum will not be included in JSON export (for compatability to older version)
|
||||
nOfElementsInFilter := bf.ElemNum
|
||||
|
||||
// check if item is in the filter
|
||||
isIn := bf.Has([]byte("butter")) // should be true
|
||||
isNotIn := bf.Has([]byte("Butter")) // should be false
|
||||
|
||||
// 'add only if item is new' to the bloomfilter
|
||||
added := bf.AddIfNotHas([]byte("butter")) // should be false because 'butter' is already in the set
|
||||
added = bf.AddIfNotHas([]byte("buTTer")) // should be true because 'buTTer' is new
|
||||
|
||||
// thread safe versions for concurrent use: AddTS, HasTS, AddIfNotHasTS
|
||||
// add one item
|
||||
bf.AddTS([]byte("peanutbutter"))
|
||||
// check if item is in the filter
|
||||
isIn = bf.HasTS([]byte("peanutbutter")) // should be true
|
||||
isNotIn = bf.HasTS([]byte("peanutButter")) // should be false
|
||||
// 'add only if item is new' to the bloomfilter
|
||||
added = bf.AddIfNotHasTS([]byte("butter")) // should be false because 'peanutbutter' is already in the set
|
||||
added = bf.AddIfNotHasTS([]byte("peanutbuTTer")) // should be true because 'penutbuTTer' is new
|
||||
|
||||
// convert to JSON ([]byte)
|
||||
Json := bf.JSONMarshal()
|
||||
|
||||
// bloomfilters Mutex is exposed for external un-/locking
|
||||
// i.e. mutex lock while doing JSON conversion
|
||||
bf.Mtx.Lock()
|
||||
Json = bf.JSONMarshal()
|
||||
bf.Mtx.Unlock()
|
||||
|
||||
// restore a bloom filter from storage
|
||||
bfNew := bbloom.JSONUnmarshal(Json)
|
||||
|
||||
isInNew := bfNew.Has([]byte("butter")) // should be true
|
||||
isNotInNew := bfNew.Has([]byte("Butter")) // should be false
|
||||
|
||||
```
|
||||
|
||||
to work with the bloom filter.
|
||||
|
||||
### why 'fast'?
|
||||
|
||||
It's about 3 times faster than William Fitzgeralds bitset bloom filter https://github.com/willf/bloom . And it is about so fast as my []bool set variant for Boom filters (see https://github.com/AndreasBriese/bloom ) but having a 8times smaller memory footprint:
|
||||
|
||||
|
||||
Bloom filter (filter size 524288, 7 hashlocs)
|
||||
github.com/AndreasBriese/bbloom 'Add' 65536 items (10 repetitions): 6595800 ns (100 ns/op)
|
||||
github.com/AndreasBriese/bbloom 'Has' 65536 items (10 repetitions): 5986600 ns (91 ns/op)
|
||||
github.com/AndreasBriese/bloom 'Add' 65536 items (10 repetitions): 6304684 ns (96 ns/op)
|
||||
github.com/AndreasBriese/bloom 'Has' 65536 items (10 repetitions): 6568663 ns (100 ns/op)
|
||||
|
||||
github.com/willf/bloom 'Add' 65536 items (10 repetitions): 24367224 ns (371 ns/op)
|
||||
github.com/willf/bloom 'Test' 65536 items (10 repetitions): 21881142 ns (333 ns/op)
|
||||
github.com/dataence/bloom/standard 'Add' 65536 items (10 repetitions): 23041644 ns (351 ns/op)
|
||||
github.com/dataence/bloom/standard 'Check' 65536 items (10 repetitions): 19153133 ns (292 ns/op)
|
||||
github.com/cabello/bloom 'Add' 65536 items (10 repetitions): 131921507 ns (2012 ns/op)
|
||||
github.com/cabello/bloom 'Contains' 65536 items (10 repetitions): 131108962 ns (2000 ns/op)
|
||||
|
||||
(on MBPro15 OSX10.8.5 i7 4Core 2.4Ghz)
|
||||
|
||||
|
||||
With 32bit bloom filters (bloom32) using modified sdbm, bloom32 does hashing with only 2 bit shifts, one xor and one substraction per byte. smdb is about as fast as fnv64a but gives less collisions with the dataset (see mask above). bloom.New(float64(10 * 1<<16),float64(7)) populated with 1<<16 random items from the dataset (see above) and tested against the rest results in less than 0.05% collisions.
|
||||
284
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/bbloom.go
generated
vendored
284
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/bbloom.go
generated
vendored
|
|
@ -1,284 +0,0 @@
|
|||
// The MIT License (MIT)
|
||||
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
|
||||
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
// 2019/08/25 code revision to reduce unsafe use
|
||||
// Parts are adopted from the fork at ipfs/bbloom after performance rev by
|
||||
// Steve Allen (https://github.com/Stebalien)
|
||||
// (see https://github.com/ipfs/bbloom/blob/master/bbloom.go)
|
||||
// -> func Has
|
||||
// -> func set
|
||||
// -> func add
|
||||
|
||||
package bbloom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// helper
|
||||
// not needed anymore by Set
|
||||
// var mask = []uint8{1, 2, 4, 8, 16, 32, 64, 128}
|
||||
|
||||
func getSize(ui64 uint64) (size uint64, exponent uint64) {
|
||||
if ui64 < uint64(512) {
|
||||
ui64 = uint64(512)
|
||||
}
|
||||
size = uint64(1)
|
||||
for size < ui64 {
|
||||
size <<= 1
|
||||
exponent++
|
||||
}
|
||||
return size, exponent
|
||||
}
|
||||
|
||||
func calcSizeByWrongPositives(numEntries, wrongs float64) (uint64, uint64) {
|
||||
size := -1 * numEntries * math.Log(wrongs) / math.Pow(float64(0.69314718056), 2)
|
||||
locs := math.Ceil(float64(0.69314718056) * size / numEntries)
|
||||
return uint64(size), uint64(locs)
|
||||
}
|
||||
|
||||
// New
|
||||
// returns a new bloomfilter
|
||||
func New(params ...float64) (bloomfilter Bloom) {
|
||||
var entries, locs uint64
|
||||
if len(params) == 2 {
|
||||
if params[1] < 1 {
|
||||
entries, locs = calcSizeByWrongPositives(params[0], params[1])
|
||||
} else {
|
||||
entries, locs = uint64(params[0]), uint64(params[1])
|
||||
}
|
||||
} else {
|
||||
log.Fatal("usage: New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(3)) or New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(0.03))")
|
||||
}
|
||||
size, exponent := getSize(uint64(entries))
|
||||
bloomfilter = Bloom{
|
||||
Mtx: &sync.Mutex{},
|
||||
sizeExp: exponent,
|
||||
size: size - 1,
|
||||
setLocs: locs,
|
||||
shift: 64 - exponent,
|
||||
}
|
||||
bloomfilter.Size(size)
|
||||
return bloomfilter
|
||||
}
|
||||
|
||||
// NewWithBoolset
|
||||
// takes a []byte slice and number of locs per entry
|
||||
// returns the bloomfilter with a bitset populated according to the input []byte
|
||||
func NewWithBoolset(bs *[]byte, locs uint64) (bloomfilter Bloom) {
|
||||
bloomfilter = New(float64(len(*bs)<<3), float64(locs))
|
||||
for i, b := range *bs {
|
||||
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bloomfilter.bitset[0])) + uintptr(i))) = b
|
||||
}
|
||||
return bloomfilter
|
||||
}
|
||||
|
||||
// bloomJSONImExport
|
||||
// Im/Export structure used by JSONMarshal / JSONUnmarshal
|
||||
type bloomJSONImExport struct {
|
||||
FilterSet []byte
|
||||
SetLocs uint64
|
||||
}
|
||||
|
||||
// JSONUnmarshal
|
||||
// takes JSON-Object (type bloomJSONImExport) as []bytes
|
||||
// returns Bloom object
|
||||
func JSONUnmarshal(dbData []byte) Bloom {
|
||||
bloomImEx := bloomJSONImExport{}
|
||||
json.Unmarshal(dbData, &bloomImEx)
|
||||
buf := bytes.NewBuffer(bloomImEx.FilterSet)
|
||||
bs := buf.Bytes()
|
||||
bf := NewWithBoolset(&bs, bloomImEx.SetLocs)
|
||||
return bf
|
||||
}
|
||||
|
||||
//
|
||||
// Bloom filter
|
||||
type Bloom struct {
|
||||
Mtx *sync.Mutex
|
||||
ElemNum uint64
|
||||
bitset []uint64
|
||||
sizeExp uint64
|
||||
size uint64
|
||||
setLocs uint64
|
||||
shift uint64
|
||||
}
|
||||
|
||||
// <--- http://www.cse.yorku.ca/~oz/hash.html
|
||||
// modified Berkeley DB Hash (32bit)
|
||||
// hash is casted to l, h = 16bit fragments
|
||||
// func (bl Bloom) absdbm(b *[]byte) (l, h uint64) {
|
||||
// hash := uint64(len(*b))
|
||||
// for _, c := range *b {
|
||||
// hash = uint64(c) + (hash << 6) + (hash << bl.sizeExp) - hash
|
||||
// }
|
||||
// h = hash >> bl.shift
|
||||
// l = hash << bl.shift >> bl.shift
|
||||
// return l, h
|
||||
// }
|
||||
|
||||
// Update: found sipHash of Jean-Philippe Aumasson & Daniel J. Bernstein to be even faster than absdbm()
|
||||
// https://131002.net/siphash/
|
||||
// siphash was implemented for Go by Dmitry Chestnykh https://github.com/dchest/siphash
|
||||
|
||||
// Add
|
||||
// set the bit(s) for entry; Adds an entry to the Bloom filter
|
||||
func (bl *Bloom) Add(entry []byte) {
|
||||
l, h := bl.sipHash(entry)
|
||||
for i := uint64(0); i < bl.setLocs; i++ {
|
||||
bl.set((h + i*l) & bl.size)
|
||||
bl.ElemNum++
|
||||
}
|
||||
}
|
||||
|
||||
// AddTS
|
||||
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
|
||||
func (bl *Bloom) AddTS(entry []byte) {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
bl.Add(entry)
|
||||
}
|
||||
|
||||
// Has
|
||||
// check if bit(s) for entry is/are set
|
||||
// returns true if the entry was added to the Bloom Filter
|
||||
func (bl Bloom) Has(entry []byte) bool {
|
||||
l, h := bl.sipHash(entry)
|
||||
res := true
|
||||
for i := uint64(0); i < bl.setLocs; i++ {
|
||||
res = res && bl.isSet((h+i*l)&bl.size)
|
||||
// https://github.com/ipfs/bbloom/commit/84e8303a9bfb37b2658b85982921d15bbb0fecff
|
||||
// // Branching here (early escape) is not worth it
|
||||
// // This is my conclusion from benchmarks
|
||||
// // (prevents loop unrolling)
|
||||
// switch bl.IsSet((h + i*l) & bl.size) {
|
||||
// case false:
|
||||
// return false
|
||||
// }
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// HasTS
|
||||
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
|
||||
func (bl *Bloom) HasTS(entry []byte) bool {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
return bl.Has(entry)
|
||||
}
|
||||
|
||||
// AddIfNotHas
|
||||
// Only Add entry if it's not present in the bloomfilter
|
||||
// returns true if entry was added
|
||||
// returns false if entry was allready registered in the bloomfilter
|
||||
func (bl Bloom) AddIfNotHas(entry []byte) (added bool) {
|
||||
if bl.Has(entry) {
|
||||
return added
|
||||
}
|
||||
bl.Add(entry)
|
||||
return true
|
||||
}
|
||||
|
||||
// AddIfNotHasTS
|
||||
// Tread safe: Only Add entry if it's not present in the bloomfilter
|
||||
// returns true if entry was added
|
||||
// returns false if entry was allready registered in the bloomfilter
|
||||
func (bl *Bloom) AddIfNotHasTS(entry []byte) (added bool) {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
return bl.AddIfNotHas(entry)
|
||||
}
|
||||
|
||||
// Size
|
||||
// make Bloom filter with as bitset of size sz
|
||||
func (bl *Bloom) Size(sz uint64) {
|
||||
bl.bitset = make([]uint64, sz>>6)
|
||||
}
|
||||
|
||||
// Clear
|
||||
// resets the Bloom filter
|
||||
func (bl *Bloom) Clear() {
|
||||
bs := bl.bitset
|
||||
for i := range bs {
|
||||
bs[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Set
|
||||
// set the bit[idx] of bitsit
|
||||
func (bl *Bloom) set(idx uint64) {
|
||||
// ommit unsafe
|
||||
// *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3))) |= mask[idx%8]
|
||||
bl.bitset[idx>>6] |= 1 << (idx % 64)
|
||||
}
|
||||
|
||||
// IsSet
|
||||
// check if bit[idx] of bitset is set
|
||||
// returns true/false
|
||||
func (bl *Bloom) isSet(idx uint64) bool {
|
||||
// ommit unsafe
|
||||
// return (((*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)))) >> (idx % 8)) & 1) == 1
|
||||
return bl.bitset[idx>>6]&(1<<(idx%64)) != 0
|
||||
}
|
||||
|
||||
// JSONMarshal
|
||||
// returns JSON-object (type bloomJSONImExport) as []byte
|
||||
func (bl Bloom) JSONMarshal() []byte {
|
||||
bloomImEx := bloomJSONImExport{}
|
||||
bloomImEx.SetLocs = uint64(bl.setLocs)
|
||||
bloomImEx.FilterSet = make([]byte, len(bl.bitset)<<3)
|
||||
for i := range bloomImEx.FilterSet {
|
||||
bloomImEx.FilterSet[i] = *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[0])) + uintptr(i)))
|
||||
}
|
||||
data, err := json.Marshal(bloomImEx)
|
||||
if err != nil {
|
||||
log.Fatal("json.Marshal failed: ", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// // alternative hashFn
|
||||
// func (bl Bloom) fnv64a(b *[]byte) (l, h uint64) {
|
||||
// h64 := fnv.New64a()
|
||||
// h64.Write(*b)
|
||||
// hash := h64.Sum64()
|
||||
// h = hash >> 32
|
||||
// l = hash << 32 >> 32
|
||||
// return l, h
|
||||
// }
|
||||
//
|
||||
// // <-- http://partow.net/programming/hashfunctions/index.html
|
||||
// // citation: An algorithm proposed by Donald E. Knuth in The Art Of Computer Programming Volume 3,
|
||||
// // under the topic of sorting and search chapter 6.4.
|
||||
// // modified to fit with boolset-length
|
||||
// func (bl Bloom) DEKHash(b *[]byte) (l, h uint64) {
|
||||
// hash := uint64(len(*b))
|
||||
// for _, c := range *b {
|
||||
// hash = ((hash << 5) ^ (hash >> bl.shift)) ^ uint64(c)
|
||||
// }
|
||||
// h = hash >> bl.shift
|
||||
// l = hash << bl.sizeExp >> bl.sizeExp
|
||||
// return l, h
|
||||
// }
|
||||
225
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/sipHash.go
generated
vendored
225
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/sipHash.go
generated
vendored
|
|
@ -1,225 +0,0 @@
|
|||
// Written in 2012 by Dmitry Chestnykh.
|
||||
//
|
||||
// To the extent possible under law, the author have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
// http://creativecommons.org/publicdomain/zero/1.0/
|
||||
//
|
||||
// Package siphash implements SipHash-2-4, a fast short-input PRF
|
||||
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
|
||||
|
||||
package bbloom
|
||||
|
||||
// Hash returns the 64-bit SipHash-2-4 of the given byte slice with two 64-bit
|
||||
// parts of 128-bit key: k0 and k1.
|
||||
func (bl Bloom) sipHash(p []byte) (l, h uint64) {
|
||||
// Initialization.
|
||||
v0 := uint64(8317987320269560794) // k0 ^ 0x736f6d6570736575
|
||||
v1 := uint64(7237128889637516672) // k1 ^ 0x646f72616e646f6d
|
||||
v2 := uint64(7816392314733513934) // k0 ^ 0x6c7967656e657261
|
||||
v3 := uint64(8387220255325274014) // k1 ^ 0x7465646279746573
|
||||
t := uint64(len(p)) << 56
|
||||
|
||||
// Compression.
|
||||
for len(p) >= 8 {
|
||||
|
||||
m := uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 |
|
||||
uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
|
||||
|
||||
v3 ^= m
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
v0 ^= m
|
||||
p = p[8:]
|
||||
}
|
||||
|
||||
// Compress last block.
|
||||
switch len(p) {
|
||||
case 7:
|
||||
t |= uint64(p[6]) << 48
|
||||
fallthrough
|
||||
case 6:
|
||||
t |= uint64(p[5]) << 40
|
||||
fallthrough
|
||||
case 5:
|
||||
t |= uint64(p[4]) << 32
|
||||
fallthrough
|
||||
case 4:
|
||||
t |= uint64(p[3]) << 24
|
||||
fallthrough
|
||||
case 3:
|
||||
t |= uint64(p[2]) << 16
|
||||
fallthrough
|
||||
case 2:
|
||||
t |= uint64(p[1]) << 8
|
||||
fallthrough
|
||||
case 1:
|
||||
t |= uint64(p[0])
|
||||
}
|
||||
|
||||
v3 ^= t
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
v0 ^= t
|
||||
|
||||
// Finalization.
|
||||
v2 ^= 0xff
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 3.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 4.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// return v0 ^ v1 ^ v2 ^ v3
|
||||
|
||||
hash := v0 ^ v1 ^ v2 ^ v3
|
||||
h = hash >> bl.shift
|
||||
l = hash << bl.shift >> bl.shift
|
||||
return l, h
|
||||
|
||||
}
|
||||
24
backend/services/mochi/vendor/github.com/alicebob/gopher-json/LICENSE
generated
vendored
24
backend/services/mochi/vendor/github.com/alicebob/gopher-json/LICENSE
generated
vendored
|
|
@ -1,24 +0,0 @@
|
|||
This is free and unencumbered software released into the public domain.
|
||||
|
||||
Anyone is free to copy, modify, publish, use, compile, sell, or
|
||||
distribute this software, either in source code form or as a compiled
|
||||
binary, for any purpose, commercial or non-commercial, and by any
|
||||
means.
|
||||
|
||||
In jurisdictions that recognize copyright laws, the author or authors
|
||||
of this software dedicate any and all copyright interest in the
|
||||
software to the public domain. We make this dedication for the benefit
|
||||
of the public at large and to the detriment of our heirs and
|
||||
successors. We intend this dedication to be an overt act of
|
||||
relinquishment in perpetuity of all present and future rights to this
|
||||
software under copyright law.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
For more information, please refer to <http://unlicense.org/>
|
||||
7
backend/services/mochi/vendor/github.com/alicebob/gopher-json/README.md
generated
vendored
7
backend/services/mochi/vendor/github.com/alicebob/gopher-json/README.md
generated
vendored
|
|
@ -1,7 +0,0 @@
|
|||
# gopher-json [](https://godoc.org/layeh.com/gopher-json)
|
||||
|
||||
Package json is a simple JSON encoder/decoder for [gopher-lua](https://github.com/yuin/gopher-lua).
|
||||
|
||||
## License
|
||||
|
||||
Public domain
|
||||
33
backend/services/mochi/vendor/github.com/alicebob/gopher-json/doc.go
generated
vendored
33
backend/services/mochi/vendor/github.com/alicebob/gopher-json/doc.go
generated
vendored
|
|
@ -1,33 +0,0 @@
|
|||
// Package json is a simple JSON encoder/decoder for gopher-lua.
|
||||
//
|
||||
// Documentation
|
||||
//
|
||||
// The following functions are exposed by the library:
|
||||
// decode(string): Decodes a JSON string. Returns nil and an error string if
|
||||
// the string could not be decoded.
|
||||
// encode(value): Encodes a value into a JSON string. Returns nil and an error
|
||||
// string if the value could not be encoded.
|
||||
//
|
||||
// The following types are supported:
|
||||
//
|
||||
// Lua | JSON
|
||||
// ---------+-----
|
||||
// nil | null
|
||||
// number | number
|
||||
// string | string
|
||||
// table | object: when table is non-empty and has only string keys
|
||||
// | array: when table is empty, or has only sequential numeric keys
|
||||
// | starting from 1
|
||||
//
|
||||
// Attempting to encode any other Lua type will result in an error.
|
||||
//
|
||||
// Example
|
||||
//
|
||||
// Below is an example usage of the library:
|
||||
// import (
|
||||
// luajson "layeh.com/gopher-json"
|
||||
// )
|
||||
//
|
||||
// L := lua.NewState()
|
||||
// luajson.Preload(s)
|
||||
package json
|
||||
189
backend/services/mochi/vendor/github.com/alicebob/gopher-json/json.go
generated
vendored
189
backend/services/mochi/vendor/github.com/alicebob/gopher-json/json.go
generated
vendored
|
|
@ -1,189 +0,0 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// Preload adds json to the given Lua state's package.preload table. After it
|
||||
// has been preloaded, it can be loaded using require:
|
||||
//
|
||||
// local json = require("json")
|
||||
func Preload(L *lua.LState) {
|
||||
L.PreloadModule("json", Loader)
|
||||
}
|
||||
|
||||
// Loader is the module loader function.
|
||||
func Loader(L *lua.LState) int {
|
||||
t := L.NewTable()
|
||||
L.SetFuncs(t, api)
|
||||
L.Push(t)
|
||||
return 1
|
||||
}
|
||||
|
||||
var api = map[string]lua.LGFunction{
|
||||
"decode": apiDecode,
|
||||
"encode": apiEncode,
|
||||
}
|
||||
|
||||
func apiDecode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to decode"), 1)
|
||||
return 0
|
||||
}
|
||||
str := L.CheckString(1)
|
||||
|
||||
value, err := Decode(L, []byte(str))
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(value)
|
||||
return 1
|
||||
}
|
||||
|
||||
func apiEncode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to encode"), 1)
|
||||
return 0
|
||||
}
|
||||
value := L.CheckAny(1)
|
||||
|
||||
data, err := Encode(value)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(lua.LString(string(data)))
|
||||
return 1
|
||||
}
|
||||
|
||||
var (
|
||||
errNested = errors.New("cannot encode recursively nested tables to JSON")
|
||||
errSparseArray = errors.New("cannot encode sparse array")
|
||||
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
|
||||
)
|
||||
|
||||
type invalidTypeError lua.LValueType
|
||||
|
||||
func (i invalidTypeError) Error() string {
|
||||
return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
|
||||
}
|
||||
|
||||
// Encode returns the JSON encoding of value.
|
||||
func Encode(value lua.LValue) ([]byte, error) {
|
||||
return json.Marshal(jsonValue{
|
||||
LValue: value,
|
||||
visited: make(map[*lua.LTable]bool),
|
||||
})
|
||||
}
|
||||
|
||||
type jsonValue struct {
|
||||
lua.LValue
|
||||
visited map[*lua.LTable]bool
|
||||
}
|
||||
|
||||
func (j jsonValue) MarshalJSON() (data []byte, err error) {
|
||||
switch converted := j.LValue.(type) {
|
||||
case lua.LBool:
|
||||
data, err = json.Marshal(bool(converted))
|
||||
case lua.LNumber:
|
||||
data, err = json.Marshal(float64(converted))
|
||||
case *lua.LNilType:
|
||||
data = []byte(`null`)
|
||||
case lua.LString:
|
||||
data, err = json.Marshal(string(converted))
|
||||
case *lua.LTable:
|
||||
if j.visited[converted] {
|
||||
return nil, errNested
|
||||
}
|
||||
j.visited[converted] = true
|
||||
|
||||
key, value := converted.Next(lua.LNil)
|
||||
|
||||
switch key.Type() {
|
||||
case lua.LTNil: // empty table
|
||||
data = []byte(`[]`)
|
||||
case lua.LTNumber:
|
||||
arr := make([]jsonValue, 0, converted.Len())
|
||||
expectedKey := lua.LNumber(1)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTNumber {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
if expectedKey != key {
|
||||
err = errSparseArray
|
||||
return
|
||||
}
|
||||
arr = append(arr, jsonValue{value, j.visited})
|
||||
expectedKey++
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(arr)
|
||||
case lua.LTString:
|
||||
obj := make(map[string]jsonValue)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTString {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
obj[key.String()] = jsonValue{value, j.visited}
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(obj)
|
||||
default:
|
||||
err = errInvalidKeys
|
||||
}
|
||||
default:
|
||||
err = invalidTypeError(j.LValue.Type())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Decode converts the JSON encoded data to Lua values.
|
||||
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
|
||||
var value interface{}
|
||||
err := json.Unmarshal(data, &value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecodeValue(L, value), nil
|
||||
}
|
||||
|
||||
// DecodeValue converts the value to a Lua value.
|
||||
//
|
||||
// This function only converts values that the encoding/json package decodes to.
|
||||
// All other values will return lua.LNil.
|
||||
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
|
||||
switch converted := value.(type) {
|
||||
case bool:
|
||||
return lua.LBool(converted)
|
||||
case float64:
|
||||
return lua.LNumber(converted)
|
||||
case string:
|
||||
return lua.LString(converted)
|
||||
case json.Number:
|
||||
return lua.LString(converted)
|
||||
case []interface{}:
|
||||
arr := L.CreateTable(len(converted), 0)
|
||||
for _, item := range converted {
|
||||
arr.Append(DecodeValue(L, item))
|
||||
}
|
||||
return arr
|
||||
case map[string]interface{}:
|
||||
tbl := L.CreateTable(0, len(converted))
|
||||
for key, item := range converted {
|
||||
tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
|
||||
}
|
||||
return tbl
|
||||
case nil:
|
||||
return lua.LNil
|
||||
}
|
||||
|
||||
return lua.LNil
|
||||
}
|
||||
6
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/.gitignore
generated
vendored
6
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/.gitignore
generated
vendored
|
|
@ -1,6 +0,0 @@
|
|||
/integration/redis_src/
|
||||
/integration/dump.rdb
|
||||
*.swp
|
||||
/integration/nodes.conf
|
||||
.idea/
|
||||
miniredis.iml
|
||||
225
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/CHANGELOG.md
generated
vendored
225
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/CHANGELOG.md
generated
vendored
|
|
@ -1,225 +0,0 @@
|
|||
## Changelog
|
||||
|
||||
|
||||
### v2.23.0
|
||||
|
||||
- basic INFO support (thanks @kirill-a-belov)
|
||||
- support COUNT in SSCAN (thanks @Abdi-dd)
|
||||
- test and support Go 1.19
|
||||
- support LPOS (thanks @ianstarz)
|
||||
- support XPENDING, XGROUP {CREATECONSUMER,DESTROY,DELCONSUMER}, XINFO {CONSUMERS,GROUPS}, XCLAIM (thanks @sandyharvie)
|
||||
|
||||
|
||||
### v2.22.0
|
||||
|
||||
- set miniredis.DumpMaxLineLen to get more Dump() info (thanks @afjoseph)
|
||||
- fix invalid resposne of COMMAND (thanks @zsh1995)
|
||||
- fix possibility to generate duplicate IDs in XADD (thanks @readams)
|
||||
- adds support for XAUTOCLAIM min-idle parameter (thanks @readams)
|
||||
|
||||
|
||||
### v2.21.0
|
||||
|
||||
- support for GETEX (thanks @dntj)
|
||||
- support for GT and LT in ZADD (thanks @lsgndln)
|
||||
- support for XAUTOCLAIM (thanks @randall-fulton)
|
||||
|
||||
|
||||
### v2.20.0
|
||||
|
||||
- back to support Go >= 1.14 (thanks @ajatprabha and @marcind)
|
||||
|
||||
|
||||
### v2.19.0
|
||||
|
||||
- support for TYPE in SCAN (thanks @0xDiddi)
|
||||
- update BITPOS (thanks @dirkm)
|
||||
- fix a lua redis.call() return value (thanks @mpetronic)
|
||||
- update ZRANGE (thanks @valdemarpereira)
|
||||
|
||||
|
||||
### v2.18.0
|
||||
|
||||
- support for ZUNION (thanks @propan)
|
||||
- support for COPY (thanks @matiasinsaurralde and @rockitbaby)
|
||||
- support for LMOVE (thanks @btwear)
|
||||
|
||||
|
||||
### v2.17.0
|
||||
|
||||
- added miniredis.RunT(t)
|
||||
|
||||
|
||||
### v2.16.1
|
||||
|
||||
- fix ZINTERSTORE with wets (thanks @lingjl2010 and @okhowang)
|
||||
- fix exclusive ranges in XRANGE (thanks @joseotoro)
|
||||
|
||||
|
||||
### v2.16.0
|
||||
|
||||
- simplify some code (thanks @zonque)
|
||||
- support for EXAT/PXAT in SET
|
||||
- support for XTRIM (thanks @joseotoro)
|
||||
- support for ZRANDMEMBER
|
||||
- support for redis.log() in lua (thanks @dirkm)
|
||||
|
||||
|
||||
### v2.15.2
|
||||
|
||||
- Fix race condition in blocking code (thanks @zonque and @robx)
|
||||
- XREAD accepts '$' as ID (thanks @bradengroom)
|
||||
|
||||
|
||||
### v2.15.1
|
||||
|
||||
- EVAL should cache the script (thanks @guoshimin)
|
||||
|
||||
|
||||
### v2.15.0
|
||||
|
||||
- target redis 6.2 and added new args to various commands
|
||||
- support for all hyperlog commands (thanks @ilbaktin)
|
||||
- support for GETDEL (thanks @wszaranski)
|
||||
|
||||
|
||||
### v2.14.5
|
||||
|
||||
- added XPENDING
|
||||
- support for BLOCK option in XREAD and XREADGROUP
|
||||
|
||||
|
||||
### v2.14.4
|
||||
|
||||
- fix BITPOS error (thanks @xiaoyuzdy)
|
||||
- small fixes for XREAD, XACK, and XDEL. Mostly error cases.
|
||||
- fix empty EXEC return type (thanks @ashanbrown)
|
||||
- fix XDEL (thanks @svakili and @yvesf)
|
||||
- fix FLUSHALL for streams (thanks @svakili)
|
||||
|
||||
|
||||
### v2.14.3
|
||||
|
||||
- fix problem where Lua code didn't set the selected DB
|
||||
- update to redis 6.0.10 (thanks @lazappa)
|
||||
|
||||
|
||||
### v2.14.2
|
||||
|
||||
- update LUA dependency
|
||||
- deal with (p)unsubscribe when there are no channels
|
||||
|
||||
|
||||
### v2.14.1
|
||||
|
||||
- mod tidy
|
||||
|
||||
|
||||
### v2.14.0
|
||||
|
||||
- support for HELLO and the RESP3 protocol
|
||||
- KEEPTTL in SET (thanks @johnpena)
|
||||
|
||||
|
||||
### v2.13.3
|
||||
|
||||
- support Go 1.14 and 1.15
|
||||
- update the `Check...()` methods
|
||||
- support for XREAD (thanks @pieterlexis)
|
||||
|
||||
|
||||
### v2.13.2
|
||||
|
||||
- Use SAN instead of CN in self signed cert for testing (thanks @johejo)
|
||||
- Travis CI now tests against the most recent two versions of Go (thanks @johejo)
|
||||
- changed unit and integration tests to compare raw payloads, not parsed payloads
|
||||
- remove "redigo" dependency
|
||||
|
||||
|
||||
### v2.13.1
|
||||
|
||||
- added HSTRLEN
|
||||
- minimal support for ACL users in AUTH
|
||||
|
||||
|
||||
### v2.13.0
|
||||
|
||||
- added RunTLS(...)
|
||||
- added SetError(...)
|
||||
|
||||
|
||||
### v2.12.0
|
||||
|
||||
- redis 6
|
||||
- Lua json update (thanks @gsmith85)
|
||||
- CLUSTER commands (thanks @kratisto)
|
||||
- fix TOUCH
|
||||
- fix a shutdown race condition
|
||||
|
||||
|
||||
### v2.11.4
|
||||
|
||||
- ZUNIONSTORE now supports standard set types (thanks @wshirey)
|
||||
|
||||
|
||||
### v2.11.3
|
||||
|
||||
- support for TOUCH (thanks @cleroux)
|
||||
- support for cluster and stream commands (thanks @kak-tus)
|
||||
|
||||
|
||||
### v2.11.2
|
||||
|
||||
- make sure Lua code is executed concurrently
|
||||
- add command GEORADIUSBYMEMBER (thanks @kyeett)
|
||||
|
||||
|
||||
### v2.11.1
|
||||
|
||||
- globals protection for Lua code (thanks @vk-outreach)
|
||||
- HSET update (thanks @carlgreen)
|
||||
- fix BLPOP block on shutdown (thanks @Asalle)
|
||||
|
||||
|
||||
### v2.11.0
|
||||
|
||||
- added XRANGE/XREVRANGE, XADD, and XLEN (thanks @skateinmars)
|
||||
- added GEODIST
|
||||
- improved precision for geohashes, closer to what real redis does
|
||||
- use 128bit floats internally for INCRBYFLOAT and related (thanks @timnd)
|
||||
|
||||
|
||||
### v2.10.1
|
||||
|
||||
- added m.Server()
|
||||
|
||||
|
||||
### v2.10.0
|
||||
|
||||
- added UNLINK
|
||||
- fix DEL zero-argument case
|
||||
- cleanup some direct access commands
|
||||
- added GEOADD, GEOPOS, GEORADIUS, and GEORADIUS_RO
|
||||
|
||||
|
||||
### v2.9.1
|
||||
|
||||
- fix issue with ZRANGEBYLEX
|
||||
- fix issue with BRPOPLPUSH and direct access
|
||||
|
||||
|
||||
### v2.9.0
|
||||
|
||||
- proper versioned import of github.com/gomodule/redigo (thanks @yfei1)
|
||||
- fix messages generated by PSUBSCRIBE
|
||||
- optional internal seed (thanks @zikaeroh)
|
||||
|
||||
|
||||
### v2.8.0
|
||||
|
||||
Proper `v2` in go.mod.
|
||||
|
||||
|
||||
### older
|
||||
|
||||
See https://github.com/alicebob/miniredis/releases for the full changelog
|
||||
21
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/LICENSE
generated
vendored
21
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/LICENSE
generated
vendored
|
|
@ -1,21 +0,0 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Harmen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
12
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/Makefile
generated
vendored
12
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/Makefile
generated
vendored
|
|
@ -1,12 +0,0 @@
|
|||
.PHONY: all test testrace int
|
||||
|
||||
all: test
|
||||
|
||||
test:
|
||||
go test ./...
|
||||
|
||||
testrace:
|
||||
go test -race ./...
|
||||
|
||||
int:
|
||||
${MAKE} -C integration all
|
||||
333
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/README.md
generated
vendored
333
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/README.md
generated
vendored
|
|
@ -1,333 +0,0 @@
|
|||
# Miniredis
|
||||
|
||||
Pure Go Redis test server, used in Go unittests.
|
||||
|
||||
|
||||
##
|
||||
|
||||
Sometimes you want to test code which uses Redis, without making it a full-blown
|
||||
integration test.
|
||||
Miniredis implements (parts of) the Redis server, to be used in unittests. It
|
||||
enables a simple, cheap, in-memory, Redis replacement, with a real TCP interface. Think of it as the Redis version of `net/http/httptest`.
|
||||
|
||||
It saves you from using mock code, and since the redis server lives in the
|
||||
test process you can query for values directly, without going through the server
|
||||
stack.
|
||||
|
||||
There are no dependencies on external binaries, so you can easily integrate it in automated build processes.
|
||||
|
||||
Be sure to import v2:
|
||||
```
|
||||
import "github.com/alicebob/miniredis/v2"
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
Implemented commands:
|
||||
|
||||
- Connection (complete)
|
||||
- AUTH -- see RequireAuth()
|
||||
- ECHO
|
||||
- HELLO -- see RequireUserAuth()
|
||||
- PING
|
||||
- SELECT
|
||||
- SWAPDB
|
||||
- QUIT
|
||||
- Key
|
||||
- COPY
|
||||
- DEL
|
||||
- EXISTS
|
||||
- EXPIRE
|
||||
- EXPIREAT
|
||||
- KEYS
|
||||
- MOVE
|
||||
- PERSIST
|
||||
- PEXPIRE
|
||||
- PEXPIREAT
|
||||
- PTTL
|
||||
- RENAME
|
||||
- RENAMENX
|
||||
- RANDOMKEY -- see m.Seed(...)
|
||||
- SCAN
|
||||
- TOUCH
|
||||
- TTL
|
||||
- TYPE
|
||||
- UNLINK
|
||||
- Transactions (complete)
|
||||
- DISCARD
|
||||
- EXEC
|
||||
- MULTI
|
||||
- UNWATCH
|
||||
- WATCH
|
||||
- Server
|
||||
- DBSIZE
|
||||
- FLUSHALL
|
||||
- FLUSHDB
|
||||
- TIME -- returns time.Now() or value set by SetTime()
|
||||
- COMMAND -- partly
|
||||
- INFO -- partly, returns only "clients" section with one field "connected_clients"
|
||||
- String keys (complete)
|
||||
- APPEND
|
||||
- BITCOUNT
|
||||
- BITOP
|
||||
- BITPOS
|
||||
- DECR
|
||||
- DECRBY
|
||||
- GET
|
||||
- GETBIT
|
||||
- GETRANGE
|
||||
- GETSET
|
||||
- GETDEL
|
||||
- GETEX
|
||||
- INCR
|
||||
- INCRBY
|
||||
- INCRBYFLOAT
|
||||
- MGET
|
||||
- MSET
|
||||
- MSETNX
|
||||
- PSETEX
|
||||
- SET
|
||||
- SETBIT
|
||||
- SETEX
|
||||
- SETNX
|
||||
- SETRANGE
|
||||
- STRLEN
|
||||
- Hash keys (complete)
|
||||
- HDEL
|
||||
- HEXISTS
|
||||
- HGET
|
||||
- HGETALL
|
||||
- HINCRBY
|
||||
- HINCRBYFLOAT
|
||||
- HKEYS
|
||||
- HLEN
|
||||
- HMGET
|
||||
- HMSET
|
||||
- HSET
|
||||
- HSETNX
|
||||
- HSTRLEN
|
||||
- HVALS
|
||||
- HSCAN
|
||||
- List keys (complete)
|
||||
- BLPOP
|
||||
- BRPOP
|
||||
- BRPOPLPUSH
|
||||
- LINDEX
|
||||
- LINSERT
|
||||
- LLEN
|
||||
- LPOP
|
||||
- LPUSH
|
||||
- LPUSHX
|
||||
- LRANGE
|
||||
- LREM
|
||||
- LSET
|
||||
- LTRIM
|
||||
- RPOP
|
||||
- RPOPLPUSH
|
||||
- RPUSH
|
||||
- RPUSHX
|
||||
- LMOVE
|
||||
- Pub/Sub (complete)
|
||||
- PSUBSCRIBE
|
||||
- PUBLISH
|
||||
- PUBSUB
|
||||
- PUNSUBSCRIBE
|
||||
- SUBSCRIBE
|
||||
- UNSUBSCRIBE
|
||||
- Set keys (complete)
|
||||
- SADD
|
||||
- SCARD
|
||||
- SDIFF
|
||||
- SDIFFSTORE
|
||||
- SINTER
|
||||
- SINTERSTORE
|
||||
- SISMEMBER
|
||||
- SMEMBERS
|
||||
- SMOVE
|
||||
- SPOP -- see m.Seed(...)
|
||||
- SRANDMEMBER -- see m.Seed(...)
|
||||
- SREM
|
||||
- SUNION
|
||||
- SUNIONSTORE
|
||||
- SSCAN
|
||||
- Sorted Set keys (complete)
|
||||
- ZADD
|
||||
- ZCARD
|
||||
- ZCOUNT
|
||||
- ZINCRBY
|
||||
- ZINTERSTORE
|
||||
- ZLEXCOUNT
|
||||
- ZPOPMIN
|
||||
- ZPOPMAX
|
||||
- ZRANDMEMBER
|
||||
- ZRANGE
|
||||
- ZRANGEBYLEX
|
||||
- ZRANGEBYSCORE
|
||||
- ZRANK
|
||||
- ZREM
|
||||
- ZREMRANGEBYLEX
|
||||
- ZREMRANGEBYRANK
|
||||
- ZREMRANGEBYSCORE
|
||||
- ZREVRANGE
|
||||
- ZREVRANGEBYLEX
|
||||
- ZREVRANGEBYSCORE
|
||||
- ZREVRANK
|
||||
- ZSCORE
|
||||
- ZUNION
|
||||
- ZUNIONSTORE
|
||||
- ZSCAN
|
||||
- Stream keys
|
||||
- XACK
|
||||
- XADD
|
||||
- XAUTOCLAIM
|
||||
- XCLAIM
|
||||
- XDEL
|
||||
- XGROUP CREATE
|
||||
- XGROUP CREATECONSUMER
|
||||
- XGROUP DESTROY
|
||||
- XGROUP DELCONSUMER
|
||||
- XINFO STREAM -- partly
|
||||
- XINFO GROUPS
|
||||
- XINFO CONSUMERS -- partly
|
||||
- XLEN
|
||||
- XRANGE
|
||||
- XREAD
|
||||
- XREADGROUP
|
||||
- XREVRANGE
|
||||
- XPENDING
|
||||
- XTRIM
|
||||
- Scripting
|
||||
- EVAL
|
||||
- EVALSHA
|
||||
- SCRIPT LOAD
|
||||
- SCRIPT EXISTS
|
||||
- SCRIPT FLUSH
|
||||
- GEO
|
||||
- GEOADD
|
||||
- GEODIST
|
||||
- ~~GEOHASH~~
|
||||
- GEOPOS
|
||||
- GEORADIUS
|
||||
- GEORADIUS_RO
|
||||
- GEORADIUSBYMEMBER
|
||||
- GEORADIUSBYMEMBER_RO
|
||||
- Cluster
|
||||
- CLUSTER SLOTS
|
||||
- CLUSTER KEYSLOT
|
||||
- CLUSTER NODES
|
||||
- HyperLogLog (complete)
|
||||
- PFADD
|
||||
- PFCOUNT
|
||||
- PFMERGE
|
||||
|
||||
|
||||
## TTLs, key expiration, and time
|
||||
|
||||
Since miniredis is intended to be used in unittests TTLs don't decrease
|
||||
automatically. You can use `TTL()` to get the TTL (as a time.Duration) of a
|
||||
key. It will return 0 when no TTL is set.
|
||||
|
||||
`m.FastForward(d)` can be used to decrement all TTLs. All TTLs which become <=
|
||||
0 will be removed.
|
||||
|
||||
EXPIREAT and PEXPIREAT values will be
|
||||
converted to a duration. For that you can either set m.SetTime(t) to use that
|
||||
time as the base for the (P)EXPIREAT conversion, or don't call SetTime(), in
|
||||
which case time.Now() will be used.
|
||||
|
||||
SetTime() also sets the value returned by TIME, which defaults to time.Now().
|
||||
It is not updated by FastForward, only by SetTime.
|
||||
|
||||
## Randomness and Seed()
|
||||
|
||||
Miniredis will use `math/rand`'s global RNG for randomness unless a seed is
|
||||
provided by calling `m.Seed(...)`. If a seed is provided, then miniredis will
|
||||
use its own RNG based on that seed.
|
||||
|
||||
Commands which use randomness are: RANDOMKEY, SPOP, and SRANDMEMBER.
|
||||
|
||||
## Example
|
||||
|
||||
``` Go
|
||||
|
||||
import (
|
||||
...
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
...
|
||||
)
|
||||
|
||||
func TestSomething(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
|
||||
// Optionally set some keys your code expects:
|
||||
s.Set("foo", "bar")
|
||||
s.HSet("some", "other", "key")
|
||||
|
||||
// Run your code and see if it behaves.
|
||||
// An example using the redigo library from "github.com/gomodule/redigo/redis":
|
||||
c, err := redis.Dial("tcp", s.Addr())
|
||||
_, err = c.Do("SET", "foo", "bar")
|
||||
|
||||
// Optionally check values in redis...
|
||||
if got, err := s.Get("foo"); err != nil || got != "bar" {
|
||||
t.Error("'foo' has the wrong value")
|
||||
}
|
||||
// ... or use a helper for that:
|
||||
s.CheckGet(t, "foo", "bar")
|
||||
|
||||
// TTL and expiration:
|
||||
s.Set("foo", "bar")
|
||||
s.SetTTL("foo", 10*time.Second)
|
||||
s.FastForward(11 * time.Second)
|
||||
if s.Exists("foo") {
|
||||
t.Fatal("'foo' should not have existed anymore")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Not supported
|
||||
|
||||
Commands which will probably not be implemented:
|
||||
|
||||
- CLUSTER (all)
|
||||
- ~~CLUSTER *~~
|
||||
- ~~READONLY~~
|
||||
- ~~READWRITE~~
|
||||
- Key
|
||||
- ~~DUMP~~
|
||||
- ~~MIGRATE~~
|
||||
- ~~OBJECT~~
|
||||
- ~~RESTORE~~
|
||||
- ~~WAIT~~
|
||||
- Scripting
|
||||
- ~~SCRIPT DEBUG~~
|
||||
- ~~SCRIPT KILL~~
|
||||
- Server
|
||||
- ~~BGSAVE~~
|
||||
- ~~BGWRITEAOF~~
|
||||
- ~~CLIENT *~~
|
||||
- ~~CONFIG *~~
|
||||
- ~~DEBUG *~~
|
||||
- ~~LASTSAVE~~
|
||||
- ~~MONITOR~~
|
||||
- ~~ROLE~~
|
||||
- ~~SAVE~~
|
||||
- ~~SHUTDOWN~~
|
||||
- ~~SLAVEOF~~
|
||||
- ~~SLOWLOG~~
|
||||
- ~~SYNC~~
|
||||
|
||||
|
||||
## &c.
|
||||
|
||||
Integration tests are run against Redis 6.2.6. The [./integration](./integration/) subdir
|
||||
compares miniredis against a real redis instance.
|
||||
|
||||
The Redis 6 RESP3 protocol is supported. If there are problems, please open
|
||||
an issue.
|
||||
|
||||
If you want to test Redis Sentinel have a look at [minisentinel](https://github.com/Bose/minisentinel).
|
||||
|
||||
A changelog is kept at [CHANGELOG.md](https://github.com/alicebob/miniredis/blob/master/CHANGELOG.md).
|
||||
|
||||
[](https://pkg.go.dev/github.com/alicebob/miniredis/v2)
|
||||
63
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/check.go
generated
vendored
63
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/check.go
generated
vendored
|
|
@ -1,63 +0,0 @@
|
|||
package miniredis
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// T is implemented by Testing.T
|
||||
type T interface {
|
||||
Helper()
|
||||
Errorf(string, ...interface{})
|
||||
}
|
||||
|
||||
// CheckGet does not call Errorf() iff there is a string key with the
|
||||
// expected value. Normal use case is `m.CheckGet(t, "username", "theking")`.
|
||||
func (m *Miniredis) CheckGet(t T, key, expected string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.Get(key)
|
||||
if err != nil {
|
||||
t.Errorf("GET error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
if found != expected {
|
||||
t.Errorf("GET error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CheckList does not call Errorf() iff there is a list key with the
|
||||
// expected values.
|
||||
// Normal use case is `m.CheckGet(t, "favorite_colors", "red", "green", "infrared")`.
|
||||
func (m *Miniredis) CheckList(t T, key string, expected ...string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.List(key)
|
||||
if err != nil {
|
||||
t.Errorf("List error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(expected, found) {
|
||||
t.Errorf("List error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CheckSet does not call Errorf() iff there is a set key with the
|
||||
// expected values.
|
||||
// Normal use case is `m.CheckSet(t, "visited", "Rome", "Stockholm", "Dublin")`.
|
||||
func (m *Miniredis) CheckSet(t T, key string, expected ...string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.Members(key)
|
||||
if err != nil {
|
||||
t.Errorf("Set error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
sort.Strings(expected)
|
||||
if !reflect.DeepEqual(expected, found) {
|
||||
t.Errorf("Set error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
67
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_cluster.go
generated
vendored
67
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_cluster.go
generated
vendored
|
|
@ -1,67 +0,0 @@
|
|||
// Commands from https://redis.io/commands#cluster
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsCluster handles some cluster operations.
|
||||
func commandsCluster(m *Miniredis) {
|
||||
m.srv.Register("CLUSTER", m.cmdCluster)
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdCluster(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "SLOTS":
|
||||
m.cmdClusterSlots(c, cmd, args)
|
||||
case "KEYSLOT":
|
||||
m.cmdClusterKeySlot(c, cmd, args)
|
||||
case "NODES":
|
||||
m.cmdClusterNodes(c, cmd, args)
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf("ERR 'CLUSTER %s' not supported", strings.Join(args, " ")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CLUSTER SLOTS
|
||||
func (m *Miniredis) cmdClusterSlots(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteLen(1)
|
||||
c.WriteLen(3)
|
||||
c.WriteInt(0)
|
||||
c.WriteInt(16383)
|
||||
c.WriteLen(3)
|
||||
c.WriteBulk(m.srv.Addr().IP.String())
|
||||
c.WriteInt(m.srv.Addr().Port)
|
||||
c.WriteBulk("09dbe9720cda62f7865eabc5fd8857c5d2678366")
|
||||
})
|
||||
}
|
||||
|
||||
// CLUSTER KEYSLOT
|
||||
func (m *Miniredis) cmdClusterKeySlot(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteInt(163)
|
||||
})
|
||||
}
|
||||
|
||||
// CLUSTER NODES
|
||||
func (m *Miniredis) cmdClusterNodes(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteBulk("e7d1eecce10fd6bb5eb35b9f99a514335d9ba9ca 127.0.0.1:7000@7000 myself,master - 0 0 1 connected 0-16383")
|
||||
})
|
||||
}
|
||||
2045
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_command.go
generated
vendored
2045
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_command.go
generated
vendored
File diff suppressed because it is too large
Load Diff
284
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_connection.go
generated
vendored
284
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_connection.go
generated
vendored
|
|
@ -1,284 +0,0 @@
|
|||
// Commands from https://redis.io/commands#connection
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
func commandsConnection(m *Miniredis) {
|
||||
m.srv.Register("AUTH", m.cmdAuth)
|
||||
m.srv.Register("ECHO", m.cmdEcho)
|
||||
m.srv.Register("HELLO", m.cmdHello)
|
||||
m.srv.Register("PING", m.cmdPing)
|
||||
m.srv.Register("QUIT", m.cmdQuit)
|
||||
m.srv.Register("SELECT", m.cmdSelect)
|
||||
m.srv.Register("SWAPDB", m.cmdSwapdb)
|
||||
}
|
||||
|
||||
// PING
|
||||
func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
payload := ""
|
||||
if len(args) > 0 {
|
||||
payload = args[0]
|
||||
}
|
||||
|
||||
// PING is allowed in subscribed state
|
||||
if sub := getCtx(c).subscriber; sub != nil {
|
||||
c.Block(func(c *server.Writer) {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("pong")
|
||||
c.WriteBulk(payload)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if payload == "" {
|
||||
c.WriteInline("PONG")
|
||||
return
|
||||
}
|
||||
c.WriteBulk(payload)
|
||||
})
|
||||
}
|
||||
|
||||
// AUTH
|
||||
func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 2 {
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
if getCtx(c).nested {
|
||||
c.WriteError(msgNotFromScripts)
|
||||
return
|
||||
}
|
||||
|
||||
var opts = struct {
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
username: "default",
|
||||
password: args[0],
|
||||
}
|
||||
if len(args) == 2 {
|
||||
opts.username, opts.password = args[0], args[1]
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if len(m.passwords) == 0 && opts.username == "default" {
|
||||
c.WriteError("ERR AUTH <password> called without any password configured for the default user. Are you sure your configuration is correct?")
|
||||
return
|
||||
}
|
||||
setPW, ok := m.passwords[opts.username]
|
||||
if !ok {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
if setPW != opts.password {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.authenticated = true
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// HELLO
|
||||
func (m *Miniredis) cmdHello(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
version int
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.version, "ERR Protocol version is not an integer or out of range"); !ok {
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
|
||||
switch opts.version {
|
||||
case 2, 3:
|
||||
default:
|
||||
c.WriteError("NOPROTO unsupported protocol version")
|
||||
return
|
||||
}
|
||||
|
||||
var checkAuth bool
|
||||
for len(args) > 0 {
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "AUTH":
|
||||
if len(args) < 3 {
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
opts.username, opts.password, args = args[1], args[2], args[3:]
|
||||
checkAuth = true
|
||||
case "SETNAME":
|
||||
if len(args) < 2 {
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
_, args = args[1], args[2:]
|
||||
default:
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.passwords) == 0 && opts.username == "default" {
|
||||
// redis ignores legacy "AUTH" if it's not enabled.
|
||||
checkAuth = false
|
||||
}
|
||||
if checkAuth {
|
||||
setPW, ok := m.passwords[opts.username]
|
||||
if !ok {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
if setPW != opts.password {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
getCtx(c).authenticated = true
|
||||
}
|
||||
|
||||
c.Resp3 = opts.version == 3
|
||||
|
||||
c.WriteMapLen(7)
|
||||
c.WriteBulk("server")
|
||||
c.WriteBulk("miniredis")
|
||||
c.WriteBulk("version")
|
||||
c.WriteBulk("6.0.5")
|
||||
c.WriteBulk("proto")
|
||||
c.WriteInt(opts.version)
|
||||
c.WriteBulk("id")
|
||||
c.WriteInt(42)
|
||||
c.WriteBulk("mode")
|
||||
c.WriteBulk("standalone")
|
||||
c.WriteBulk("role")
|
||||
c.WriteBulk("master")
|
||||
c.WriteBulk("modules")
|
||||
c.WriteLen(0)
|
||||
}
|
||||
|
||||
// ECHO
|
||||
func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
msg := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteBulk(msg)
|
||||
})
|
||||
}
|
||||
|
||||
// SELECT
|
||||
func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.isValidCMD(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
id int
|
||||
}
|
||||
if ok := optInt(c, args[0], &opts.id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if opts.id < 0 {
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
setDirty(c)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.selectedDB = opts.id
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// SWAPDB
|
||||
func (m *Miniredis) cmdSwapdb(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
id1 int
|
||||
id2 int
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.id1, "ERR invalid first DB index"); !ok {
|
||||
return
|
||||
}
|
||||
if ok := optIntErr(c, args[1], &opts.id2, "ERR invalid second DB index"); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if opts.id1 < 0 || opts.id2 < 0 {
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
setDirty(c)
|
||||
return
|
||||
}
|
||||
|
||||
m.swapDB(opts.id1, opts.id2)
|
||||
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// QUIT
|
||||
func (m *Miniredis) cmdQuit(c *server.Peer, cmd string, args []string) {
|
||||
// QUIT isn't transactionfied and accepts any arguments.
|
||||
c.WriteOK()
|
||||
c.Close()
|
||||
}
|
||||
669
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_generic.go
generated
vendored
669
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_generic.go
generated
vendored
|
|
@ -1,669 +0,0 @@
|
|||
// Commands from https://redis.io/commands#generic
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsGeneric handles EXPIRE, TTL, PERSIST, &c.
|
||||
func commandsGeneric(m *Miniredis) {
|
||||
m.srv.Register("COPY", m.cmdCopy)
|
||||
m.srv.Register("DEL", m.cmdDel)
|
||||
// DUMP
|
||||
m.srv.Register("EXISTS", m.cmdExists)
|
||||
m.srv.Register("EXPIRE", makeCmdExpire(m, false, time.Second))
|
||||
m.srv.Register("EXPIREAT", makeCmdExpire(m, true, time.Second))
|
||||
m.srv.Register("KEYS", m.cmdKeys)
|
||||
// MIGRATE
|
||||
m.srv.Register("MOVE", m.cmdMove)
|
||||
// OBJECT
|
||||
m.srv.Register("PERSIST", m.cmdPersist)
|
||||
m.srv.Register("PEXPIRE", makeCmdExpire(m, false, time.Millisecond))
|
||||
m.srv.Register("PEXPIREAT", makeCmdExpire(m, true, time.Millisecond))
|
||||
m.srv.Register("PTTL", m.cmdPTTL)
|
||||
m.srv.Register("RANDOMKEY", m.cmdRandomkey)
|
||||
m.srv.Register("RENAME", m.cmdRename)
|
||||
m.srv.Register("RENAMENX", m.cmdRenamenx)
|
||||
// RESTORE
|
||||
m.srv.Register("TOUCH", m.cmdTouch)
|
||||
m.srv.Register("TTL", m.cmdTTL)
|
||||
m.srv.Register("TYPE", m.cmdType)
|
||||
m.srv.Register("SCAN", m.cmdScan)
|
||||
// SORT
|
||||
m.srv.Register("UNLINK", m.cmdDel)
|
||||
}
|
||||
|
||||
// generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT
|
||||
// d is the time unit. If unix is set it'll be seen as a unixtimestamp and
|
||||
// converted to a duration.
|
||||
func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, string, []string) {
|
||||
return func(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
value int
|
||||
}
|
||||
opts.key = args[0]
|
||||
if ok := optInt(c, args[1], &opts.value); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
// Key must be present.
|
||||
if _, ok := db.keys[opts.key]; !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if unix {
|
||||
db.ttl[opts.key] = m.at(opts.value, d)
|
||||
} else {
|
||||
db.ttl[opts.key] = time.Duration(opts.value) * d
|
||||
}
|
||||
db.keyVersion[opts.key]++
|
||||
db.checkTTL(opts.key)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TOUCH
|
||||
func (m *Miniredis) cmdTouch(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count := 0
|
||||
for _, key := range args {
|
||||
if db.exists(key) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// TTL
|
||||
func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// No such key
|
||||
c.WriteInt(-2)
|
||||
return
|
||||
}
|
||||
|
||||
v, ok := db.ttl[key]
|
||||
if !ok {
|
||||
// no expire value
|
||||
c.WriteInt(-1)
|
||||
return
|
||||
}
|
||||
c.WriteInt(int(v.Seconds()))
|
||||
})
|
||||
}
|
||||
|
||||
// PTTL
|
||||
func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// no such key
|
||||
c.WriteInt(-2)
|
||||
return
|
||||
}
|
||||
|
||||
v, ok := db.ttl[key]
|
||||
if !ok {
|
||||
// no expire value
|
||||
c.WriteInt(-1)
|
||||
return
|
||||
}
|
||||
c.WriteInt(int(v.Nanoseconds() / 1000000))
|
||||
})
|
||||
}
|
||||
|
||||
// PERSIST
|
||||
func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// no such key
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := db.ttl[key]; !ok {
|
||||
// no expire value
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
delete(db.ttl, key)
|
||||
db.keyVersion[key]++
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// DEL and UNLINK
|
||||
func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count := 0
|
||||
for _, key := range args {
|
||||
if db.exists(key) {
|
||||
count++
|
||||
}
|
||||
db.del(key, true) // delete expire
|
||||
}
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// TYPE
|
||||
func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError("usage error")
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteInline("none")
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInline(t)
|
||||
})
|
||||
}
|
||||
|
||||
// EXISTS
|
||||
func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
found := 0
|
||||
for _, k := range args {
|
||||
if db.exists(k) {
|
||||
found++
|
||||
}
|
||||
}
|
||||
c.WriteInt(found)
|
||||
})
|
||||
}
|
||||
|
||||
// MOVE
|
||||
func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
targetDB int
|
||||
}
|
||||
|
||||
opts.key = args[0]
|
||||
opts.targetDB, _ = strconv.Atoi(args[1])
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if ctx.selectedDB == opts.targetDB {
|
||||
c.WriteError("ERR source and destination objects are the same")
|
||||
return
|
||||
}
|
||||
db := m.db(ctx.selectedDB)
|
||||
targetDB := m.db(opts.targetDB)
|
||||
|
||||
if !db.move(opts.key, targetDB) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// KEYS
|
||||
func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
keys, _ := matchKeys(db.allKeys(), key)
|
||||
c.WriteLen(len(keys))
|
||||
for _, s := range keys {
|
||||
c.WriteBulk(s)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RANDOMKEY
|
||||
func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if len(db.keys) == 0 {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
nr := m.randIntn(len(db.keys))
|
||||
for k := range db.keys {
|
||||
if nr == 0 {
|
||||
c.WriteBulk(k)
|
||||
return
|
||||
}
|
||||
nr--
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RENAME
|
||||
func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
from: args[0],
|
||||
to: args[1],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.from) {
|
||||
c.WriteError(msgKeyNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
db.rename(opts.from, opts.to)
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// RENAMENX
|
||||
func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
from: args[0],
|
||||
to: args[1],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.from) {
|
||||
c.WriteError(msgKeyNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if db.exists(opts.to) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
db.rename(opts.from, opts.to)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// SCAN
|
||||
func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
cursor int
|
||||
withMatch bool
|
||||
match string
|
||||
withType bool
|
||||
_type string
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.cursor, msgInvalidCursor); !ok {
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
|
||||
// MATCH, COUNT and TYPE options
|
||||
for len(args) > 0 {
|
||||
if strings.ToLower(args[0]) == "count" {
|
||||
// we do nothing with count
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if _, err := strconv.Atoi(args[1]); err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "match" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.withMatch = true
|
||||
opts.match, args = args[1], args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "type" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.withType = true
|
||||
opts._type, args = strings.ToLower(args[1]), args[2:]
|
||||
continue
|
||||
}
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
// We return _all_ (matched) keys every time.
|
||||
|
||||
if opts.cursor != 0 {
|
||||
// Invalid cursor.
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(0) // no elements
|
||||
return
|
||||
}
|
||||
|
||||
var keys []string
|
||||
|
||||
if opts.withType {
|
||||
keys = make([]string, 0)
|
||||
for k, t := range db.keys {
|
||||
// type must be given exactly; no pattern matching is performed
|
||||
if t == opts._type {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys) // To make things deterministic.
|
||||
} else {
|
||||
keys = db.allKeys()
|
||||
}
|
||||
|
||||
if opts.withMatch {
|
||||
keys, _ = matchKeys(keys, opts.match)
|
||||
}
|
||||
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(len(keys))
|
||||
for _, k := range keys {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// COPY
|
||||
func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts = struct {
|
||||
from string
|
||||
to string
|
||||
destinationDB int
|
||||
replace bool
|
||||
}{
|
||||
destinationDB: -1,
|
||||
}
|
||||
|
||||
opts.from, opts.to, args = args[0], args[1], args[2:]
|
||||
for len(args) > 0 {
|
||||
switch strings.ToLower(args[0]) {
|
||||
case "db":
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
db, err := strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if db < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
return
|
||||
}
|
||||
opts.destinationDB = db
|
||||
args = args[2:]
|
||||
case "replace":
|
||||
opts.replace = true
|
||||
args = args[1:]
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
fromDB, toDB := ctx.selectedDB, opts.destinationDB
|
||||
if toDB == -1 {
|
||||
toDB = fromDB
|
||||
}
|
||||
|
||||
if fromDB == toDB && opts.from == opts.to {
|
||||
c.WriteError("ERR source and destination objects are the same")
|
||||
return
|
||||
}
|
||||
|
||||
if !m.db(fromDB).exists(opts.from) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if !opts.replace {
|
||||
if m.db(toDB).exists(opts.to) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.copy(m.db(fromDB), opts.from, m.db(toDB), opts.to)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
609
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_geo.go
generated
vendored
609
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_geo.go
generated
vendored
|
|
@ -1,609 +0,0 @@
|
|||
// Commands from https://redis.io/commands#geo
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsGeo handles GEOADD, GEORADIUS etc.
|
||||
func commandsGeo(m *Miniredis) {
|
||||
m.srv.Register("GEOADD", m.cmdGeoadd)
|
||||
m.srv.Register("GEODIST", m.cmdGeodist)
|
||||
m.srv.Register("GEOPOS", m.cmdGeopos)
|
||||
m.srv.Register("GEORADIUS", m.cmdGeoradius)
|
||||
m.srv.Register("GEORADIUS_RO", m.cmdGeoradius)
|
||||
m.srv.Register("GEORADIUSBYMEMBER", m.cmdGeoradiusbymember)
|
||||
m.srv.Register("GEORADIUSBYMEMBER_RO", m.cmdGeoradiusbymember)
|
||||
}
|
||||
|
||||
// GEOADD
|
||||
func (m *Miniredis) cmdGeoadd(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 || len(args[1:])%3 != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
key, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != "zset" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
toSet := map[string]float64{}
|
||||
for len(args) > 2 {
|
||||
rawLong, rawLat, name := args[0], args[1], args[2]
|
||||
args = args[3:]
|
||||
longitude, err := strconv.ParseFloat(rawLong, 64)
|
||||
if err != nil {
|
||||
c.WriteError("ERR value is not a valid float")
|
||||
return
|
||||
}
|
||||
latitude, err := strconv.ParseFloat(rawLat, 64)
|
||||
if err != nil {
|
||||
c.WriteError("ERR value is not a valid float")
|
||||
return
|
||||
}
|
||||
|
||||
if latitude < -85.05112878 ||
|
||||
latitude > 85.05112878 ||
|
||||
longitude < -180 ||
|
||||
longitude > 180 {
|
||||
c.WriteError(fmt.Sprintf("ERR invalid longitude,latitude pair %.6f,%.6f", longitude, latitude))
|
||||
return
|
||||
}
|
||||
|
||||
toSet[name] = float64(toGeohash(longitude, latitude))
|
||||
}
|
||||
|
||||
set := 0
|
||||
for name, score := range toSet {
|
||||
if db.ssetAdd(key, score, name) {
|
||||
set++
|
||||
}
|
||||
}
|
||||
c.WriteInt(set)
|
||||
})
|
||||
}
|
||||
|
||||
// GEODIST
|
||||
func (m *Miniredis) cmdGeodist(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, from, to, args := args[0], args[1], args[2], args[3:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
if !db.exists(key) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if db.t(key) != "zset" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
unit := "m"
|
||||
if len(args) > 0 {
|
||||
unit, args = args[0], args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
toMeter := parseUnit(unit)
|
||||
if toMeter == 0 {
|
||||
c.WriteError(msgUnsupportedUnit)
|
||||
return
|
||||
}
|
||||
|
||||
members := db.sortedsetKeys[key]
|
||||
fromD, okFrom := members.get(from)
|
||||
toD, okTo := members.get(to)
|
||||
if !okFrom || !okTo {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
fromLo, fromLat := fromGeohash(uint64(fromD))
|
||||
toLo, toLat := fromGeohash(uint64(toD))
|
||||
|
||||
dist := distance(fromLat, fromLo, toLat, toLo) / toMeter
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", dist))
|
||||
})
|
||||
}
|
||||
|
||||
// GEOPOS
|
||||
func (m *Miniredis) cmdGeopos(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
key, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != "zset" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(args))
|
||||
for _, l := range args {
|
||||
if !db.ssetExists(key, l) {
|
||||
c.WriteLen(-1)
|
||||
continue
|
||||
}
|
||||
score := db.ssetScore(key, l)
|
||||
c.WriteLen(2)
|
||||
long, lat := fromGeohash(uint64(score))
|
||||
c.WriteBulk(fmt.Sprintf("%f", long))
|
||||
c.WriteBulk(fmt.Sprintf("%f", lat))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type geoDistance struct {
|
||||
Name string
|
||||
Score float64
|
||||
Distance float64
|
||||
Longitude float64
|
||||
Latitude float64
|
||||
}
|
||||
|
||||
// GEORADIUS and GEORADIUS_RO
|
||||
func (m *Miniredis) cmdGeoradius(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 5 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
longitude, err := strconv.ParseFloat(args[1], 64)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
latitude, err := strconv.ParseFloat(args[2], 64)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
radius, err := strconv.ParseFloat(args[3], 64)
|
||||
if err != nil || radius < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
toMeter := parseUnit(args[4])
|
||||
if toMeter == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
args = args[5:]
|
||||
|
||||
var opts struct {
|
||||
withDist bool
|
||||
withCoord bool
|
||||
direction direction // unsorted
|
||||
count int
|
||||
withStore bool
|
||||
storeKey string
|
||||
withStoredist bool
|
||||
storedistKey string
|
||||
}
|
||||
for len(args) > 0 {
|
||||
arg := args[0]
|
||||
args = args[1:]
|
||||
switch strings.ToUpper(arg) {
|
||||
case "WITHCOORD":
|
||||
opts.withCoord = true
|
||||
case "WITHDIST":
|
||||
opts.withDist = true
|
||||
case "ASC":
|
||||
opts.direction = asc
|
||||
case "DESC":
|
||||
opts.direction = desc
|
||||
case "COUNT":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
n, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if n <= 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR COUNT must be > 0")
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
opts.count = n
|
||||
case "STORE":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStore = true
|
||||
opts.storeKey = args[0]
|
||||
args = args[1:]
|
||||
case "STOREDIST":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStoredist = true
|
||||
opts.storedistKey = args[0]
|
||||
args = args[1:]
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToUpper(cmd) == "GEORADIUS_RO" && (opts.withStore || opts.withStoredist) {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
|
||||
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
|
||||
return
|
||||
}
|
||||
|
||||
db := m.db(ctx.selectedDB)
|
||||
members := db.ssetElements(key)
|
||||
|
||||
matches := withinRadius(members, longitude, latitude, radius*toMeter)
|
||||
|
||||
// deal with ASC/DESC
|
||||
if opts.direction != unsorted {
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if opts.direction == desc {
|
||||
return matches[i].Distance > matches[j].Distance
|
||||
}
|
||||
return matches[i].Distance < matches[j].Distance
|
||||
})
|
||||
}
|
||||
|
||||
// deal with COUNT
|
||||
if opts.count > 0 && len(matches) > opts.count {
|
||||
matches = matches[:opts.count]
|
||||
}
|
||||
|
||||
// deal with "STORE x"
|
||||
if opts.withStore {
|
||||
db.del(opts.storeKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storeKey, member.Score, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
// deal with "STOREDIST x"
|
||||
if opts.withStoredist {
|
||||
db.del(opts.storedistKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storedistKey, member.Distance/toMeter, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(matches))
|
||||
for _, member := range matches {
|
||||
if !opts.withDist && !opts.withCoord {
|
||||
c.WriteBulk(member.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
len := 1
|
||||
if opts.withDist {
|
||||
len++
|
||||
}
|
||||
if opts.withCoord {
|
||||
len++
|
||||
}
|
||||
c.WriteLen(len)
|
||||
c.WriteBulk(member.Name)
|
||||
if opts.withDist {
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/toMeter))
|
||||
}
|
||||
if opts.withCoord {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GEORADIUSBYMEMBER and GEORADIUSBYMEMBER_RO
|
||||
func (m *Miniredis) cmdGeoradiusbymember(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 4 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
member string
|
||||
radius float64
|
||||
toMeter float64
|
||||
|
||||
withDist bool
|
||||
withCoord bool
|
||||
direction direction // unsorted
|
||||
count int
|
||||
withStore bool
|
||||
storeKey string
|
||||
withStoredist bool
|
||||
storedistKey string
|
||||
}{
|
||||
key: args[0],
|
||||
member: args[1],
|
||||
}
|
||||
|
||||
r, err := strconv.ParseFloat(args[2], 64)
|
||||
if err != nil || r < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
opts.radius = r
|
||||
|
||||
opts.toMeter = parseUnit(args[3])
|
||||
if opts.toMeter == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
args = args[4:]
|
||||
|
||||
for len(args) > 0 {
|
||||
arg := args[0]
|
||||
args = args[1:]
|
||||
switch strings.ToUpper(arg) {
|
||||
case "WITHCOORD":
|
||||
opts.withCoord = true
|
||||
case "WITHDIST":
|
||||
opts.withDist = true
|
||||
case "ASC":
|
||||
opts.direction = asc
|
||||
case "DESC":
|
||||
opts.direction = desc
|
||||
case "COUNT":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
n, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if n <= 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR COUNT must be > 0")
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
opts.count = n
|
||||
case "STORE":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStore = true
|
||||
opts.storeKey = args[0]
|
||||
args = args[1:]
|
||||
case "STOREDIST":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStoredist = true
|
||||
opts.storedistKey = args[0]
|
||||
args = args[1:]
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToUpper(cmd) == "GEORADIUSBYMEMBER_RO" && (opts.withStore || opts.withStoredist) {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
|
||||
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
|
||||
return
|
||||
}
|
||||
|
||||
db := m.db(ctx.selectedDB)
|
||||
if !db.exists(opts.key) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(opts.key) != "zset" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// get position of member
|
||||
if !db.ssetExists(opts.key, opts.member) {
|
||||
c.WriteError("ERR could not decode requested zset member")
|
||||
return
|
||||
}
|
||||
score := db.ssetScore(opts.key, opts.member)
|
||||
longitude, latitude := fromGeohash(uint64(score))
|
||||
|
||||
members := db.ssetElements(opts.key)
|
||||
matches := withinRadius(members, longitude, latitude, opts.radius*opts.toMeter)
|
||||
|
||||
// deal with ASC/DESC
|
||||
if opts.direction != unsorted {
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if opts.direction == desc {
|
||||
return matches[i].Distance > matches[j].Distance
|
||||
}
|
||||
return matches[i].Distance < matches[j].Distance
|
||||
})
|
||||
}
|
||||
|
||||
// deal with COUNT
|
||||
if opts.count > 0 && len(matches) > opts.count {
|
||||
matches = matches[:opts.count]
|
||||
}
|
||||
|
||||
// deal with "STORE x"
|
||||
if opts.withStore {
|
||||
db.del(opts.storeKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storeKey, member.Score, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
// deal with "STOREDIST x"
|
||||
if opts.withStoredist {
|
||||
db.del(opts.storedistKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storedistKey, member.Distance/opts.toMeter, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(matches))
|
||||
for _, member := range matches {
|
||||
if !opts.withDist && !opts.withCoord {
|
||||
c.WriteBulk(member.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
len := 1
|
||||
if opts.withDist {
|
||||
len++
|
||||
}
|
||||
if opts.withCoord {
|
||||
len++
|
||||
}
|
||||
c.WriteLen(len)
|
||||
c.WriteBulk(member.Name)
|
||||
if opts.withDist {
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/opts.toMeter))
|
||||
}
|
||||
if opts.withCoord {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func withinRadius(members []ssElem, longitude, latitude, radius float64) []geoDistance {
|
||||
matches := []geoDistance{}
|
||||
for _, el := range members {
|
||||
elLo, elLat := fromGeohash(uint64(el.score))
|
||||
distanceInMeter := distance(latitude, longitude, elLat, elLo)
|
||||
|
||||
if distanceInMeter <= radius {
|
||||
matches = append(matches, geoDistance{
|
||||
Name: el.member,
|
||||
Score: el.score,
|
||||
Distance: distanceInMeter,
|
||||
Longitude: elLo,
|
||||
Latitude: elLat,
|
||||
})
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
func parseUnit(u string) float64 {
|
||||
switch u {
|
||||
case "m":
|
||||
return 1
|
||||
case "km":
|
||||
return 1000
|
||||
case "mi":
|
||||
return 1609.34
|
||||
case "ft":
|
||||
return 0.3048
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
683
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_hash.go
generated
vendored
683
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_hash.go
generated
vendored
|
|
@ -1,683 +0,0 @@
|
|||
// Commands from https://redis.io/commands#hash
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsHash handles all hash value operations.
|
||||
func commandsHash(m *Miniredis) {
|
||||
m.srv.Register("HDEL", m.cmdHdel)
|
||||
m.srv.Register("HEXISTS", m.cmdHexists)
|
||||
m.srv.Register("HGET", m.cmdHget)
|
||||
m.srv.Register("HGETALL", m.cmdHgetall)
|
||||
m.srv.Register("HINCRBY", m.cmdHincrby)
|
||||
m.srv.Register("HINCRBYFLOAT", m.cmdHincrbyfloat)
|
||||
m.srv.Register("HKEYS", m.cmdHkeys)
|
||||
m.srv.Register("HLEN", m.cmdHlen)
|
||||
m.srv.Register("HMGET", m.cmdHmget)
|
||||
m.srv.Register("HMSET", m.cmdHmset)
|
||||
m.srv.Register("HSET", m.cmdHset)
|
||||
m.srv.Register("HSETNX", m.cmdHsetnx)
|
||||
m.srv.Register("HSTRLEN", m.cmdHstrlen)
|
||||
m.srv.Register("HVALS", m.cmdHvals)
|
||||
m.srv.Register("HSCAN", m.cmdHscan)
|
||||
}
|
||||
|
||||
// HSET
|
||||
func (m *Miniredis) cmdHset(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, pairs := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if len(pairs)%2 == 1 {
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
if t, ok := db.keys[key]; ok && t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
new := db.hashSet(key, pairs...)
|
||||
c.WriteInt(new)
|
||||
})
|
||||
}
|
||||
|
||||
// HSETNX
|
||||
func (m *Miniredis) cmdHsetnx(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
field string
|
||||
value string
|
||||
}{
|
||||
key: args[0],
|
||||
field: args[1],
|
||||
value: args[2],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[opts.key]; ok && t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := db.hashKeys[opts.key]; !ok {
|
||||
db.hashKeys[opts.key] = map[string]string{}
|
||||
db.keys[opts.key] = "hash"
|
||||
}
|
||||
_, ok := db.hashKeys[opts.key][opts.field]
|
||||
if ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
db.hashKeys[opts.key][opts.field] = opts.value
|
||||
db.keyVersion[opts.key]++
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// HMSET
|
||||
func (m *Miniredis) cmdHmset(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, args := args[0], args[1:]
|
||||
if len(args)%2 != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[key]; ok && t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
for len(args) > 0 {
|
||||
field, value := args[0], args[1]
|
||||
args = args[2:]
|
||||
db.hashSet(key, field, value)
|
||||
}
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// HGET
|
||||
func (m *Miniredis) cmdHget(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, field := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
value, ok := db.hashKeys[key][field]
|
||||
if !ok {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
c.WriteBulk(value)
|
||||
})
|
||||
}
|
||||
|
||||
// HDEL
|
||||
func (m *Miniredis) cmdHdel(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
fields []string
|
||||
}{
|
||||
key: args[0],
|
||||
fields: args[1:],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[opts.key]
|
||||
if !ok {
|
||||
// No key is zero deleted
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
deleted := 0
|
||||
for _, f := range opts.fields {
|
||||
_, ok := db.hashKeys[opts.key][f]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
delete(db.hashKeys[opts.key], f)
|
||||
deleted++
|
||||
}
|
||||
c.WriteInt(deleted)
|
||||
|
||||
// Nothing left. Remove the whole key.
|
||||
if len(db.hashKeys[opts.key]) == 0 {
|
||||
db.del(opts.key, true)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HEXISTS
|
||||
func (m *Miniredis) cmdHexists(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
field string
|
||||
}{
|
||||
key: args[0],
|
||||
field: args[1],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[opts.key]
|
||||
if !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := db.hashKeys[opts.key][opts.field]; !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// HGETALL
|
||||
func (m *Miniredis) cmdHgetall(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteMapLen(0)
|
||||
return
|
||||
}
|
||||
if t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteMapLen(len(db.hashKeys[key]))
|
||||
for _, k := range db.hashFields(key) {
|
||||
c.WriteBulk(k)
|
||||
c.WriteBulk(db.hashGet(key, k))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HKEYS
|
||||
func (m *Miniredis) cmdHkeys(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteLen(0)
|
||||
return
|
||||
}
|
||||
if db.t(key) != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
fields := db.hashFields(key)
|
||||
c.WriteLen(len(fields))
|
||||
for _, f := range fields {
|
||||
c.WriteBulk(f)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HSTRLEN
|
||||
func (m *Miniredis) cmdHstrlen(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
hash, key := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[hash]
|
||||
if !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
keys := db.hashKeys[hash]
|
||||
c.WriteInt(len(keys[key]))
|
||||
})
|
||||
}
|
||||
|
||||
// HVALS
|
||||
func (m *Miniredis) cmdHvals(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteLen(0)
|
||||
return
|
||||
}
|
||||
if t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
vals := db.hashValues(key)
|
||||
c.WriteLen(len(vals))
|
||||
for _, v := range vals {
|
||||
c.WriteBulk(v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HLEN
|
||||
func (m *Miniredis) cmdHlen(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(len(db.hashKeys[key]))
|
||||
})
|
||||
}
|
||||
|
||||
// HMGET
|
||||
func (m *Miniredis) cmdHmget(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[key]; ok && t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
f, ok := db.hashKeys[key]
|
||||
if !ok {
|
||||
f = map[string]string{}
|
||||
}
|
||||
|
||||
c.WriteLen(len(args) - 1)
|
||||
for _, k := range args[1:] {
|
||||
v, ok := f[k]
|
||||
if !ok {
|
||||
c.WriteNull()
|
||||
continue
|
||||
}
|
||||
c.WriteBulk(v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HINCRBY
|
||||
func (m *Miniredis) cmdHincrby(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
field string
|
||||
delta int
|
||||
}{
|
||||
key: args[0],
|
||||
field: args[1],
|
||||
}
|
||||
if ok := optInt(c, args[2], &opts.delta); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[opts.key]; ok && t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
v, err := db.hashIncr(opts.key, opts.field, opts.delta)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
c.WriteInt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// HINCRBYFLOAT
|
||||
func (m *Miniredis) cmdHincrbyfloat(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
field string
|
||||
delta *big.Float
|
||||
}{
|
||||
key: args[0],
|
||||
field: args[1],
|
||||
}
|
||||
delta, _, err := big.ParseFloat(args[2], 10, 128, 0)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidFloat)
|
||||
return
|
||||
}
|
||||
opts.delta = delta
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[opts.key]; ok && t != "hash" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
v, err := db.hashIncrfloat(opts.key, opts.field, opts.delta)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
c.WriteBulk(formatBig(v))
|
||||
})
|
||||
}
|
||||
|
||||
// HSCAN
|
||||
func (m *Miniredis) cmdHscan(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
cursor int
|
||||
withMatch bool
|
||||
match string
|
||||
}{
|
||||
key: args[0],
|
||||
}
|
||||
if ok := optIntErr(c, args[1], &opts.cursor, msgInvalidCursor); !ok {
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
|
||||
// MATCH and COUNT options
|
||||
for len(args) > 0 {
|
||||
if strings.ToLower(args[0]) == "count" {
|
||||
// we do nothing with count
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
_, err := strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "match" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.withMatch = true
|
||||
opts.match, args = args[1], args[2:]
|
||||
continue
|
||||
}
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
// return _all_ (matched) keys every time
|
||||
|
||||
if opts.cursor != 0 {
|
||||
// Invalid cursor.
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(0) // no elements
|
||||
return
|
||||
}
|
||||
if db.exists(opts.key) && db.t(opts.key) != "hash" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
members := db.hashFields(opts.key)
|
||||
if opts.withMatch {
|
||||
members, _ = matchKeys(members, opts.match)
|
||||
}
|
||||
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
// HSCAN gives key, values.
|
||||
c.WriteLen(len(members) * 2)
|
||||
for _, k := range members {
|
||||
c.WriteBulk(k)
|
||||
c.WriteBulk(db.hashGet(opts.key, k))
|
||||
}
|
||||
})
|
||||
}
|
||||
95
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_hll.go
generated
vendored
95
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_hll.go
generated
vendored
|
|
@ -1,95 +0,0 @@
|
|||
package miniredis
|
||||
|
||||
import "github.com/alicebob/miniredis/v2/server"
|
||||
|
||||
// commandsHll handles all hll related operations.
|
||||
func commandsHll(m *Miniredis) {
|
||||
m.srv.Register("PFADD", m.cmdPfadd)
|
||||
m.srv.Register("PFCOUNT", m.cmdPfcount)
|
||||
m.srv.Register("PFMERGE", m.cmdPfmerge)
|
||||
}
|
||||
|
||||
// PFADD
|
||||
func (m *Miniredis) cmdPfadd(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, items := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != "hll" {
|
||||
c.WriteError(ErrNotValidHllValue.Error())
|
||||
return
|
||||
}
|
||||
|
||||
altered := db.hllAdd(key, items...)
|
||||
c.WriteInt(altered)
|
||||
})
|
||||
}
|
||||
|
||||
// PFCOUNT
|
||||
func (m *Miniredis) cmdPfcount(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count, err := db.hllCount(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// PFMERGE
|
||||
func (m *Miniredis) cmdPfmerge(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if err := db.hllMerge(keys); err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
40
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_info.go
generated
vendored
40
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_info.go
generated
vendored
|
|
@ -1,40 +0,0 @@
|
|||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// Command 'INFO' from https://redis.io/commands/info/
|
||||
func (m *Miniredis) cmdInfo(c *server.Peer, cmd string, args []string) {
|
||||
if !m.isValidCMD(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
const (
|
||||
clientsSectionName = "clients"
|
||||
clientsSectionContent = "# Clients\nconnected_clients:%d\r\n"
|
||||
)
|
||||
|
||||
var result string
|
||||
|
||||
for _, key := range args {
|
||||
if key != clientsSectionName {
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf("section (%s) is not supported", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
result = fmt.Sprintf(clientsSectionContent, m.Server().ClientsLen())
|
||||
|
||||
c.WriteBulk(result)
|
||||
})
|
||||
}
|
||||
986
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_list.go
generated
vendored
986
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/cmd_list.go
generated
vendored
|
|
@ -1,986 +0,0 @@
|
|||
// Commands from https://redis.io/commands#list
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
type leftright int
|
||||
|
||||
const (
|
||||
left leftright = iota
|
||||
right
|
||||
)
|
||||
|
||||
// commandsList handles list commands (mostly L*)
|
||||
func commandsList(m *Miniredis) {
|
||||
m.srv.Register("BLPOP", m.cmdBlpop)
|
||||
m.srv.Register("BRPOP", m.cmdBrpop)
|
||||
m.srv.Register("BRPOPLPUSH", m.cmdBrpoplpush)
|
||||
m.srv.Register("LINDEX", m.cmdLindex)
|
||||
m.srv.Register("LPOS", m.cmdLpos)
|
||||
m.srv.Register("LINSERT", m.cmdLinsert)
|
||||
m.srv.Register("LLEN", m.cmdLlen)
|
||||
m.srv.Register("LPOP", m.cmdLpop)
|
||||
m.srv.Register("LPUSH", m.cmdLpush)
|
||||
m.srv.Register("LPUSHX", m.cmdLpushx)
|
||||
m.srv.Register("LRANGE", m.cmdLrange)
|
||||
m.srv.Register("LREM", m.cmdLrem)
|
||||
m.srv.Register("LSET", m.cmdLset)
|
||||
m.srv.Register("LTRIM", m.cmdLtrim)
|
||||
m.srv.Register("RPOP", m.cmdRpop)
|
||||
m.srv.Register("RPOPLPUSH", m.cmdRpoplpush)
|
||||
m.srv.Register("RPUSH", m.cmdRpush)
|
||||
m.srv.Register("RPUSHX", m.cmdRpushx)
|
||||
m.srv.Register("LMOVE", m.cmdLmove)
|
||||
}
|
||||
|
||||
// BLPOP
|
||||
func (m *Miniredis) cmdBlpop(c *server.Peer, cmd string, args []string) {
|
||||
m.cmdBXpop(c, cmd, args, left)
|
||||
}
|
||||
|
||||
// BRPOP
|
||||
func (m *Miniredis) cmdBrpop(c *server.Peer, cmd string, args []string) {
|
||||
m.cmdBXpop(c, cmd, args, right)
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdBXpop(c *server.Peer, cmd string, args []string, lr leftright) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
timeoutS := args[len(args)-1]
|
||||
keys := args[:len(args)-1]
|
||||
|
||||
timeout, err := strconv.Atoi(timeoutS)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidTimeout)
|
||||
return
|
||||
}
|
||||
if timeout < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgNegTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
blocking(
|
||||
m,
|
||||
c,
|
||||
time.Duration(timeout)*time.Second,
|
||||
func(c *server.Peer, ctx *connCtx) bool {
|
||||
db := m.db(ctx.selectedDB)
|
||||
for _, key := range keys {
|
||||
if !db.exists(key) {
|
||||
continue
|
||||
}
|
||||
if db.t(key) != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return true
|
||||
}
|
||||
|
||||
if len(db.listKeys[key]) == 0 {
|
||||
continue
|
||||
}
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(key)
|
||||
var v string
|
||||
switch lr {
|
||||
case left:
|
||||
v = db.listLpop(key)
|
||||
case right:
|
||||
v = db.listPop(key)
|
||||
}
|
||||
c.WriteBulk(v)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
func(c *server.Peer) {
|
||||
// timeout
|
||||
c.WriteLen(-1)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// LINDEX
|
||||
func (m *Miniredis) cmdLindex(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, offsets := args[0], args[1]
|
||||
|
||||
offset, err := strconv.Atoi(offsets)
|
||||
if err != nil || offsets == "-0" {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
// No such key
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if t != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[key]
|
||||
if offset < 0 {
|
||||
offset = len(l) + offset
|
||||
}
|
||||
if offset < 0 || offset > len(l)-1 {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
c.WriteBulk(l[offset])
|
||||
})
|
||||
}
|
||||
|
||||
// LPOS key element [RANK rank] [COUNT num-matches] [MAXLEN len]
|
||||
func (m *Miniredis) cmdLpos(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
// Extract options from arguments if present.
|
||||
//
|
||||
// Redis allows duplicate options and uses the last specified.
|
||||
// `LPOS key term RANK 1 RANK 2` is effectively the same as
|
||||
// `LPOS key term RANK 2`
|
||||
if len(args)%2 == 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
rank, count := 1, 1 // Default values
|
||||
var maxlen int // Default value is the list length (see below)
|
||||
var countSpecified, maxlenSpecified bool
|
||||
if len(args) > 2 {
|
||||
for i := 2; i < len(args); i++ {
|
||||
if i%2 == 0 {
|
||||
val := args[i+1]
|
||||
var err error
|
||||
switch strings.ToLower(args[i]) {
|
||||
case "rank":
|
||||
if rank, err = strconv.Atoi(val); err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if rank == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgRankIsZero)
|
||||
return
|
||||
}
|
||||
case "count":
|
||||
countSpecified = true
|
||||
if count, err = strconv.Atoi(val); err != nil || count < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgCountIsNegative)
|
||||
return
|
||||
}
|
||||
case "maxlen":
|
||||
maxlenSpecified = true
|
||||
if maxlen, err = strconv.Atoi(val); err != nil || maxlen < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgMaxLengthIsNegative)
|
||||
return
|
||||
}
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
key, element := args[0], args[1]
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
// No such key
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if t != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
l := db.listKeys[key]
|
||||
|
||||
// RANK cannot be zero (see above).
|
||||
// If RANK is positive search forward (left to right).
|
||||
// If RANK is negative search backward (right to left).
|
||||
// Iterator returns true to continue iterating.
|
||||
iterate := func(iterator func(i int, e string) bool) {
|
||||
comparisons := len(l)
|
||||
// Only use max length if specified, not zero, and less than total length.
|
||||
// When max length is specified, but is zero, this means "unlimited".
|
||||
if maxlenSpecified && maxlen != 0 && maxlen < len(l) {
|
||||
comparisons = maxlen
|
||||
}
|
||||
if rank > 0 {
|
||||
for i := 0; i < comparisons; i++ {
|
||||
if resume := iterator(i, l[i]); !resume {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if rank < 0 {
|
||||
start := len(l) - 1
|
||||
end := len(l) - comparisons
|
||||
for i := start; i >= end; i-- {
|
||||
if resume := iterator(i, l[i]); !resume {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var currentRank, currentCount int
|
||||
vals := make([]int, 0, count)
|
||||
iterate(func(i int, e string) bool {
|
||||
if e == element {
|
||||
currentRank++
|
||||
// Only collect values only after surpassing the absolute value of rank.
|
||||
if rank > 0 && currentRank < rank {
|
||||
return true
|
||||
}
|
||||
if rank < 0 && currentRank < -rank {
|
||||
return true
|
||||
}
|
||||
vals = append(vals, i)
|
||||
currentCount++
|
||||
if currentCount == count {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !countSpecified && len(vals) == 0 {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if !countSpecified && len(vals) == 1 {
|
||||
c.WriteInt(vals[0])
|
||||
return
|
||||
}
|
||||
c.WriteLen(len(vals))
|
||||
for _, val := range vals {
|
||||
c.WriteInt(val)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// LINSERT
|
||||
func (m *Miniredis) cmdLinsert(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 4 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
where := 0
|
||||
switch strings.ToLower(args[1]) {
|
||||
case "before":
|
||||
where = -1
|
||||
case "after":
|
||||
where = +1
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
pivot := args[2]
|
||||
value := args[3]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
// No such key
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[key]
|
||||
for i, el := range l {
|
||||
if el != pivot {
|
||||
continue
|
||||
}
|
||||
|
||||
if where < 0 {
|
||||
l = append(l[:i], append(listKey{value}, l[i:]...)...)
|
||||
} else {
|
||||
if i == len(l)-1 {
|
||||
l = append(l, value)
|
||||
} else {
|
||||
l = append(l[:i+1], append(listKey{value}, l[i+1:]...)...)
|
||||
}
|
||||
}
|
||||
db.listKeys[key] = l
|
||||
db.keyVersion[key]++
|
||||
c.WriteInt(len(l))
|
||||
return
|
||||
}
|
||||
c.WriteInt(-1)
|
||||
})
|
||||
}
|
||||
|
||||
// LLEN
|
||||
func (m *Miniredis) cmdLlen(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
// No such key. That's zero length.
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(len(db.listKeys[key]))
|
||||
})
|
||||
}
|
||||
|
||||
// LPOP
|
||||
func (m *Miniredis) cmdLpop(c *server.Peer, cmd string, args []string) {
|
||||
m.cmdXpop(c, cmd, args, left)
|
||||
}
|
||||
|
||||
// RPOP
|
||||
func (m *Miniredis) cmdRpop(c *server.Peer, cmd string, args []string) {
|
||||
m.cmdXpop(c, cmd, args, right)
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdXpop(c *server.Peer, cmd string, args []string, lr leftright) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
withCount bool
|
||||
count int
|
||||
}
|
||||
|
||||
opts.key, args = args[0], args[1:]
|
||||
if len(args) > 0 {
|
||||
if ok := optInt(c, args[0], &opts.count); !ok {
|
||||
return
|
||||
}
|
||||
if opts.count < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgOutOfRange)
|
||||
return
|
||||
}
|
||||
opts.withCount = true
|
||||
args = args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.key) {
|
||||
// non-existing key is fine
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if db.t(opts.key) != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
if opts.withCount {
|
||||
var popped []string
|
||||
for opts.count > 0 && len(db.listKeys[opts.key]) > 0 {
|
||||
switch lr {
|
||||
case left:
|
||||
popped = append(popped, db.listLpop(opts.key))
|
||||
case right:
|
||||
popped = append(popped, db.listPop(opts.key))
|
||||
}
|
||||
opts.count -= 1
|
||||
}
|
||||
if len(popped) == 0 {
|
||||
c.WriteLen(-1)
|
||||
} else {
|
||||
c.WriteStrings(popped)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var elem string
|
||||
switch lr {
|
||||
case left:
|
||||
elem = db.listLpop(opts.key)
|
||||
case right:
|
||||
elem = db.listPop(opts.key)
|
||||
}
|
||||
c.WriteBulk(elem)
|
||||
})
|
||||
}
|
||||
|
||||
// LPUSH
|
||||
func (m *Miniredis) cmdLpush(c *server.Peer, cmd string, args []string) {
|
||||
m.cmdXpush(c, cmd, args, left)
|
||||
}
|
||||
|
||||
// RPUSH
|
||||
func (m *Miniredis) cmdRpush(c *server.Peer, cmd string, args []string) {
|
||||
m.cmdXpush(c, cmd, args, right)
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdXpush(c *server.Peer, cmd string, args []string, lr leftright) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
var newLen int
|
||||
for _, value := range args {
|
||||
switch lr {
|
||||
case left:
|
||||
newLen = db.listLpush(key, value)
|
||||
case right:
|
||||
newLen = db.listPush(key, value)
|
||||
}
|
||||
}
|
||||
c.WriteInt(newLen)
|
||||
})
|
||||
}
|
||||
|
||||
// LPUSHX
|
||||
func (m *Miniredis) cmdLpushx(c *server.Peer, cmd string, args []string) {
|
||||
m.cmdXpushx(c, cmd, args, left)
|
||||
}
|
||||
|
||||
// RPUSHX
|
||||
func (m *Miniredis) cmdRpushx(c *server.Peer, cmd string, args []string) {
|
||||
m.cmdXpushx(c, cmd, args, right)
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdXpushx(c *server.Peer, cmd string, args []string, lr leftright) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if db.t(key) != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
var newLen int
|
||||
for _, value := range args {
|
||||
switch lr {
|
||||
case left:
|
||||
newLen = db.listLpush(key, value)
|
||||
case right:
|
||||
newLen = db.listPush(key, value)
|
||||
}
|
||||
}
|
||||
c.WriteInt(newLen)
|
||||
})
|
||||
}
|
||||
|
||||
// LRANGE
|
||||
func (m *Miniredis) cmdLrange(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
start int
|
||||
end int
|
||||
}{
|
||||
key: args[0],
|
||||
}
|
||||
if ok := optInt(c, args[1], &opts.start); !ok {
|
||||
return
|
||||
}
|
||||
if ok := optInt(c, args[2], &opts.end); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[opts.key]; ok && t != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[opts.key]
|
||||
if len(l) == 0 {
|
||||
c.WriteLen(0)
|
||||
return
|
||||
}
|
||||
|
||||
rs, re := redisRange(len(l), opts.start, opts.end, false)
|
||||
c.WriteLen(re - rs)
|
||||
for _, el := range l[rs:re] {
|
||||
c.WriteBulk(el)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// LREM
|
||||
func (m *Miniredis) cmdLrem(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
count int
|
||||
value string
|
||||
}
|
||||
opts.key = args[0]
|
||||
if ok := optInt(c, args[1], &opts.count); !ok {
|
||||
return
|
||||
}
|
||||
opts.value = args[2]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if db.t(opts.key) != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[opts.key]
|
||||
if opts.count < 0 {
|
||||
reverseSlice(l)
|
||||
}
|
||||
deleted := 0
|
||||
newL := []string{}
|
||||
toDelete := len(l)
|
||||
if opts.count < 0 {
|
||||
toDelete = -opts.count
|
||||
}
|
||||
if opts.count > 0 {
|
||||
toDelete = opts.count
|
||||
}
|
||||
for _, el := range l {
|
||||
if el == opts.value {
|
||||
if toDelete > 0 {
|
||||
deleted++
|
||||
toDelete--
|
||||
continue
|
||||
}
|
||||
}
|
||||
newL = append(newL, el)
|
||||
}
|
||||
if opts.count < 0 {
|
||||
reverseSlice(newL)
|
||||
}
|
||||
if len(newL) == 0 {
|
||||
db.del(opts.key, true)
|
||||
} else {
|
||||
db.listKeys[opts.key] = newL
|
||||
db.keyVersion[opts.key]++
|
||||
}
|
||||
|
||||
c.WriteInt(deleted)
|
||||
})
|
||||
}
|
||||
|
||||
// LSET
|
||||
func (m *Miniredis) cmdLset(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
index int
|
||||
value string
|
||||
}
|
||||
opts.key = args[0]
|
||||
if ok := optInt(c, args[1], &opts.index); !ok {
|
||||
return
|
||||
}
|
||||
opts.value = args[2]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.key) {
|
||||
c.WriteError(msgKeyNotFound)
|
||||
return
|
||||
}
|
||||
if db.t(opts.key) != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[opts.key]
|
||||
index := opts.index
|
||||
if index < 0 {
|
||||
index = len(l) + index
|
||||
}
|
||||
if index < 0 || index > len(l)-1 {
|
||||
c.WriteError(msgOutOfRange)
|
||||
return
|
||||
}
|
||||
l[index] = opts.value
|
||||
db.keyVersion[opts.key]++
|
||||
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// LTRIM
|
||||
func (m *Miniredis) cmdLtrim(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
opts.key = args[0]
|
||||
if ok := optInt(c, args[1], &opts.start); !ok {
|
||||
return
|
||||
}
|
||||
if ok := optInt(c, args[2], &opts.end); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[opts.key]
|
||||
if !ok {
|
||||
c.WriteOK()
|
||||
return
|
||||
}
|
||||
if t != "list" {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
l := db.listKeys[opts.key]
|
||||
rs, re := redisRange(len(l), opts.start, opts.end, false)
|
||||
l = l[rs:re]
|
||||
if len(l) == 0 {
|
||||
db.del(opts.key, true)
|
||||
} else {
|
||||
db.listKeys[opts.key] = l
|
||||
db.keyVersion[opts.key]++
|
||||
}
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// RPOPLPUSH
|
||||
func (m *Miniredis) cmdRpoplpush(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
src, dst := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(src) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if db.t(src) != "list" || (db.exists(dst) && db.t(dst) != "list") {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
elem := db.listPop(src)
|
||||
db.listLpush(dst, elem)
|
||||
c.WriteBulk(elem)
|
||||
})
|
||||
}
|
||||
|
||||
// BRPOPLPUSH
|
||||
func (m *Miniredis) cmdBrpoplpush(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
src string
|
||||
dst string
|
||||
timeout int
|
||||
}
|
||||
opts.src = args[0]
|
||||
opts.dst = args[1]
|
||||
if ok := optIntErr(c, args[2], &opts.timeout, msgInvalidTimeout); !ok {
|
||||
return
|
||||
}
|
||||
if opts.timeout < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgNegTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
blocking(
|
||||
m,
|
||||
c,
|
||||
time.Duration(opts.timeout)*time.Second,
|
||||
func(c *server.Peer, ctx *connCtx) bool {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.src) {
|
||||
return false
|
||||
}
|
||||
if db.t(opts.src) != "list" || (db.exists(opts.dst) && db.t(opts.dst) != "list") {
|
||||
c.WriteError(msgWrongType)
|
||||
return true
|
||||
}
|
||||
if len(db.listKeys[opts.src]) == 0 {
|
||||
return false
|
||||
}
|
||||
elem := db.listPop(opts.src)
|
||||
db.listLpush(opts.dst, elem)
|
||||
c.WriteBulk(elem)
|
||||
return true
|
||||
},
|
||||
func(c *server.Peer) {
|
||||
// timeout
|
||||
c.WriteLen(-1)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// LMOVE
|
||||
func (m *Miniredis) cmdLmove(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 4 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
src string
|
||||
dst string
|
||||
srcDir string
|
||||
dstDir string
|
||||
}{
|
||||
src: args[0],
|
||||
dst: args[1],
|
||||
srcDir: strings.ToLower(args[2]),
|
||||
dstDir: strings.ToLower(args[3]),
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.src) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if db.t(opts.src) != "list" || (db.exists(opts.dst) && db.t(opts.dst) != "list") {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
var elem string
|
||||
switch opts.srcDir {
|
||||
case "left":
|
||||
elem = db.listLpop(opts.src)
|
||||
case "right":
|
||||
elem = db.listPop(opts.src)
|
||||
default:
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
switch opts.dstDir {
|
||||
case "left":
|
||||
db.listLpush(opts.dst, elem)
|
||||
case "right":
|
||||
db.listPush(opts.dst, elem)
|
||||
default:
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
c.WriteBulk(elem)
|
||||
})
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user