feat(mqtt): change proj structure to services and add mochi broker

This commit is contained in:
Leandro Antônio Farias Machado 2023-03-20 11:19:02 -03:00
parent 26c523fc10
commit 46999043a5
1020 changed files with 379507 additions and 4854 deletions

View File

@ -16,3 +16,4 @@ go.work
*.txt
*.pwd
*.acl
.idea

View File

@ -25,9 +25,9 @@ func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Println("Starting Oktopus Project TR-369 Controller Version:", VERSION)
flBroker := flag.Bool("m", false, "Defines if mosquitto container must run or not")
//flBroker := flag.Bool("m", false, "Defines if mosquitto container must run or not")
// fl_endpointId := flag.String("endpoint_id", "proto::oktopus-controller", "Defines the enpoint id the Agent must trust on.")
flSubTopic := flag.String("s", "oktopus/+/+/agent", "That's the topic agent must publish to, and the controller keeps on listening.")
flSubTopic := flag.String("s", "oktopus/+/agent/+", "That's the topic agent must publish to, and the controller keeps on listening.")
// fl_pub_topic := flag.String("pub_topic", "oktopus/v1/controller", "That's the topic controller must publish to, and the agent keeps on listening.")
flBrokerAddr := flag.String("a", "localhost", "Mqtt broker adrress")
flBrokerPort := flag.String("p", "1883", "Mqtt broker port")
@ -44,10 +44,10 @@ func main() {
flag.Usage()
os.Exit(0)
}
if *flBroker {
log.Println("Starting Mqtt Broker")
mqtt.StartMqttBroker()
}
//if *flBroker {
// log.Println("Starting Mqtt Broker")
// mqtt.StartMqttBroker()
//}
/*
This context suppress our needs, but we can use a more sofisticate
approach with cancel and timeout options passing it through paho mqtt functions.
@ -69,8 +69,6 @@ func main() {
CA: *flTlsCert,
}
log.Println()
mtp.MtpService(&mqttClient, done)
<-done

View File

@ -81,6 +81,12 @@ func startClient(addr string, port string, tlsCa string, ctx context.Context, ms
clientConfig := paho.ClientConfig{
Conn: conn,
Router: singleHandler,
OnServerDisconnect: func(disconnect *paho.Disconnect) {
log.Println("disconnected from mqtt server, reason code: ", disconnect.ReasonCode)
},
OnClientError: func(err error) {
log.Println(err)
},
}
return paho.NewClient(clientConfig)
}

View File

@ -15,7 +15,6 @@ func StartMqttBroker() {
//TODO: Start Container through Docker SDK for GO, eliminating docker-compose and shell comands.
//TODO: Create Broker with user, password and CA certificate.
//TODO: Set broker access control list to topics.
//TODO: Set MQTTv5 CONNACK packet with topic for agent to use.
cmd := exec.Command("sudo", "docker", "compose", "-f", "internal/mosquitto/docker-compose.yml", "up", "-d")

View File

@ -0,0 +1,43 @@
name: build
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.19
- name: Vet
run: go vet ./...
- name: Test
run: go test -race ./... && echo true
coverage:
runs-on: ubuntu-latest
steps:
- name: Install Go
if: success()
uses: actions/setup-go@v2
with:
go-version: 1.19.x
- name: Checkout code
uses: actions/checkout@v2
- name: Calc coverage
run: |
go test -v -covermode=count -coverprofile=coverage.out ./...
- name: Convert coverage.out to coverage.lcov
uses: jandelgado/gcov2lcov-action@v1.0.6
- name: Coveralls
uses: coverallsapp/github-action@v1.1.2
with:
github-token: ${{ secrets.github_token }}
path-to-lcov: coverage.lcov

3
backend/services/mochi/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
cmd/mqtt
.DS_Store
*.db

View File

@ -0,0 +1,103 @@
linters:
disable-all: false
fix: false # Fix found issues (if it's supported by the linter).
enable:
# - asasalint
# - asciicheck
# - bidichk
# - bodyclose
# - containedctx
# - contextcheck
#- cyclop
# - deadcode
- decorder
# - depguard
# - dogsled
# - dupl
- durationcheck
# - errchkjson
# - errname
- errorlint
# - execinquery
# - exhaustive
# - exhaustruct
# - exportloopref
#- forcetypeassert
#- forbidigo
#- funlen
#- gci
# - gochecknoglobals
# - gochecknoinits
# - gocognit
# - goconst
# - gocritic
- gocyclo
- godot
# - godox
# - goerr113
# - gofmt
# - gofumpt
# - goheader
- goimports
# - golint
# - gomnd
# - gomoddirectives
# - gomodguard
# - goprintffuncname
- gosec
- gosimple
- govet
# - grouper
# - ifshort
- importas
- ineffassign
# - interfacebloat
# - interfacer
# - ireturn
# - lll
# - maintidx
# - makezero
- maligned
- misspell
# - nakedret
# - nestif
# - nilerr
# - nilnil
# - nlreturn
# - noctx
# - nolintlint
# - nonamedreturns
# - nosnakecase
# - nosprintfhostport
# - paralleltest
# - prealloc
# - predeclared
# - promlinter
- reassign
# - revive
# - rowserrcheck
# - scopelint
# - sqlclosecheck
# - staticcheck
# - structcheck
# - stylecheck
# - tagliatelle
# - tenv
# - testpackage
# - thelper
- tparallel
# - typecheck
- unconvert
- unparam
- unused
- usestdlibvars
# - varcheck
# - varnamelen
- wastedassign
- whitespace
# - wrapcheck
# - wsl
disable:
- errcheck

View File

@ -0,0 +1,31 @@
FROM golang:1.19.0-alpine3.15 AS builder
RUN apk update
RUN apk add git
WORKDIR /app
COPY go.mod ./
COPY go.sum ./
RUN go mod download
COPY . ./
RUN go build -o /app/mochi ./cmd
FROM alpine
WORKDIR /
COPY --from=builder /app/mochi .
# tcp
EXPOSE 1883
# websockets
EXPOSE 1882
# dashboard
EXPOSE 8080
ENTRYPOINT [ "/mochi" ]

View File

@ -0,0 +1,22 @@
The MIT License (MIT)
Copyright (c) 2019, 2022 Jonathan Blake (mochi-co)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,404 @@
<p align="center">
![build status](https://github.com/mochi-co/mqtt/actions/workflows/build.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/mochi-co/mqtt/badge.svg?branch=master&v2)](https://coveralls.io/github/mochi-co/mqtt?branch=master)
[![Go Report Card](https://goreportcard.com/badge/github.com/mochi-co/mqtt)](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
[![Go Reference](https://pkg.go.dev/badge/github.com/mochi-co/mqtt.svg)](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/mochi-co/mqtt/issues)
</p>
# Mochi MQTT Broker
## The fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker server
Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) MQTT v5 broker server written in Go, designed for the development of telemetry and internet-of-things projects. The server can be used either as a standalone binary or embedded as a library in your own applications, and has been designed to be as lightweight and fast as possible, with great care taken to ensure the quality and maintainability of the project.
### What is MQTT?
MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol.
## What's new in Version 2.0.0?
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
Don't forget to use the new v2 import paths:
```go
import "github.com/mochi-co/mqtt/v2"
```
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
- User and MQTTv5 Packet Properties
- Topic Aliases
- Shared Subscriptions
- Subscription Options and Subscription Identifiers
- Message Expiry
- Client Session Expiry
- Send and Receive QoS Flow Control Quotas
- Server-side Disconnect and Auth Packets
- Will Delay Intervals
- Plus all the original MQTT features of Mochi MQTT v1, such as Full QoS(0,1,2), $SYS topics, retained messages, etc.
- Developer-centric:
- Most core broker code is now exported and accessible, for total developer control.
- Full featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
- Direct Packet Injection using special inline client, or masquerade as existing clients.
- Performant and Stable:
- Our classic trie-based Topic-Subscription model.
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
- Over a thousand carefully considered unit test scenarios.
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
> There is no upgrade path from v1.0.0. Please review the documentation and this readme to get a sense of the changes required (e.g. the v1 events system, auth, and persistence have all been replaced with the new hooks system).
### Compatibility Notes
Because of the overlap between the v5 specification and previous versions of mqtt, the server can accept both v5 and v3 clients, but note that in cases where both v5 an v3 clients are connected, properties and features provided for v5 clients will be downgraded for v3 clients (such as user properties).
Support for MQTT v3.0.0 and v3.1.1 is considered hybrid-compatibility. Where not specifically restricted in the v3 specification, more modern and safety-first v5 behaviours are used instead - such as expiry for inflight and retained messages, and clients - and quality-of-service flow control limits.
## Roadmap
- Please [open an issue](https://github.com/mochi-co/mqtt/issues) to request new features or event hooks!
- Cluster support.
- Enhanced Metrics support.
- File-based server configuration (supporting docker).
## Quick Start
### Running the Broker with Go
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the [cmd/main.go](cmd/main.go) entrypoint in the [cmd](cmd) folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners.
```
cd cmd
go build -o mqtt && ./mqtt
```
### Using Docker
A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Websocket, TCP, and Stats server:
```sh
docker build -t mochi:latest .
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
```
## Developing with Mochi MQTT
### Importing as a package
Importing Mochi MQTT as a package requires just a few lines of code to get started.
``` go
import (
"log"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
// Create the new MQTT Server.
server := mqtt.New(nil)
// Allow all connections.
_ = server.AddHook(new(auth.AllowHook), nil)
// Create a TCP listener on a standard port.
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.Serve()
if err != nil {
log.Fatal(err)
}
}
```
Examples of running the broker with various configurations can be found in the [examples](examples) folder.
#### Network Listeners
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
| Listener | Usage |
| --- | --- |
| listeners.NewTCP | A TCP listener |
| listeners.NewUnixSock | A Unix Socket listener |
| listeners.NewNet | A net.Listener listener |
| listeners.NewWebsocket | A Websocket listener |
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
A `*listeners.Config` may be passed to configure TLS.
Examples of usage can be found in the [examples](examples) folder or [cmd/main.go](cmd/main.go).
### Server Options and Capabilities
A number of configurable options are available which can be used to alter the behaviour or restrict access to certain features in the server.
```go
server := mqtt.New(&mqtt.Options{
Capabilities: mqtt.Capabilities{
ClientNetWriteBufferSize: 4096,
ClientNetReadBufferSize: 4096,
MaximumSessionExpiryInterval: 3600,
Compatibilities: mqtt.Compatibilities{
ObscureNotAuthorized: true,
},
},
SysTopicResendInterval: 10,
})
```
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
## Event Hooks
A universal event hooks system allows developers to hook into various parts of the server and client life cycle to add and modify functionality of the broker. These universal hooks are used to provide everything from authentication, persistent storage, to debugging tools.
Hooks are stackable - you can add multiple hooks to a server, and they will be run in the order they were added. Some hooks modify values, and these modified values will be passed to the subsequent hooks before being returned to the runtime code.
| Type | Import | Info |
| -- | -- | -- |
| Access Control | [mochi-co/mqtt/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to all topics. |
| Access Control | [mochi-co/mqtt/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
| Persistence | [mochi-co/mqtt/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
| Persistence | [mochi-co/mqtt/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
| Persistence | [mochi-co/mqtt/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
| Debugging | [mochi-co/mqtt/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |
Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-co/mqtt/issues) and let everyone know!
### Access Control
#### Allow Hook
By default, Mochi MQTT uses a DENY-ALL access control rule. To allow connections, this must overwritten using an Access Control hook. The simplest of these hooks is the `auth.AllowAll` hook, which provides ALLOW-ALL rules to all connections, subscriptions, and publishing. It's also the simplest hook to use:
```go
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
```
> Don't do this if you are exposing your server to the internet or untrusted networks - it should really be used for development, testing, and debugging only.
#### Auth Ledger
The Auth Ledger hook provides a sophisticated mechanism for defining access rules in a struct format. Auth ledger rules come in two forms: Auth rules (connection), and ACL rules (publish subscribe).
Auth rules have 4 optional criteria and an assertion flag:
| Criteria | Usage |
| -- | -- |
| Client | client id of the connecting client |
| Username | username of the connecting client |
| Password | password of the connecting client |
| Remote | the remote address or ip of the client |
| Allow | true (allow this user) or false (deny this user) |
ACL rules have 3 optional criteria and an filter match:
| Criteria | Usage |
| -- | -- |
| Client | client id of the connecting client |
| Username | username of the connecting client |
| Remote | the remote address or ip of the client |
| Filters | an array of filters to match |
Rules are processed in index order (0,1,2,3), returning on the first matching rule. See [hooks/auth/ledger.go](hooks/auth/ledger.go) to review the structs.
```go
server := mqtt.New(nil)
err := server.AddHook(new(auth.Hook), &auth.Options{
Ledger: &auth.Ledger{
Auth: auth.AuthRules{ // Auth disallows all by default
{Username: "peach", Password: "password1", Allow: true},
{Username: "melon", Password: "password2", Allow: true},
{Remote: "127.0.0.1:*", Allow: true},
{Remote: "localhost:*", Allow: true},
},
ACL: auth.ACLRules{ // ACL allows all by default
{Remote: "127.0.0.1:*"}, // local superuser allow all
{
// user melon can read and write to their own topic
Username: "melon", Filters: auth.Filters{
"melon/#": auth.ReadWrite,
"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
},
},
{
// Otherwise, no clients have publishing permissions
Filters: auth.Filters{
"#": auth.ReadOnly,
"updates/#": auth.Deny,
},
},
},
}
})
```
The ledger can also be stored as JSON or YAML and loaded using the Data field:
```go
err = server.AddHook(new(auth.Hook), &auth.Options{
Data: data, // build ledger from byte slice: yaml or json
})
```
See [examples/auth/encoded/main.go](examples/auth/encoded/main.go) for more information.
### Persistent Storage
#### Redis
A basic Redis storage hook is available which provides persistence for the broker. It can be added to the server in the same fashion as any other hook, with several options. It uses github.com/go-redis/redis/v8 under the hook, and is completely configurable through the Options value.
```go
err := server.AddHook(new(redis.Hook), &redis.Options{
Options: &rv8.Options{
Addr: "localhost:6379", // default redis address
Password: "", // your password
DB: 0, // your redis db
},
})
if err != nil {
log.Fatal(err)
}
```
For more information on how the redis hook works, or how to use it, see the [examples/persistence/redis/main.go](examples/persistence/redis/main.go) or [hooks/storage/redis](hooks/storage/redis) code.
#### Badger DB
There's also a BadgerDB storage hook if you prefer file based storage. It can be added and configured in much the same way as the other hooks (with somewhat less options).
```go
err := server.AddHook(new(badger.Hook), &badger.Options{
Path: badgerPath,
})
if err != nil {
log.Fatal(err)
}
```
For more information on how the badger hook works, or how to use it, see the [examples/persistence/badger/main.go](examples/persistence/badger/main.go) or [hooks/storage/badger](hooks/storage/badger) code.
There is also a BoltDB hook which has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go).
## Developing with Event Hooks
Many hooks are available for interacting with the broker and client lifecycle.
The function signatures for all the hooks and `mqtt.Hook` interface can be found in [hooks.go](hooks.go).
> The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks be used to control and modify all incoming and outgoing packets.
| Function | Usage |
| -------------------------- | -- |
| OnStarted | Called when the server has successfully started.|
| OnStopped | Called when the server has successfully stopped. |
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
| OnSysInfoTick | Called when the $SYS topic values are published out. |
| OnConnect | Called when a new client connects |
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
| OnDisconnect | Called when a client is disconnected for any reason. |
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
| OnPacketSent | Called when a packet has been sent to a client. |
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.|
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
| OnPublish | Called when a client publishes a message. Allows packet modification. |
| OnPublished | Called when a client has published a message to subscribers. |
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
| OnRetainMessage | Called then a published message is retained. |
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
| OnQosComplete | Called when the Qos flow for a message has been completed. |
| OnQosDropped | Called when an inflight message expires before completion. |
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
| OnClientExpired | Called when a client session has expired and should be deleted. |
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
| StoredClients | Returns clients, eg. from a persistent store. |
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
### Direct Publish
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
```go
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
```
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
### Packet Injection
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
```go
cl := server.NewClient(nil, "local", "inline", true)
server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "direct/publish",
Payload: []byte("scheduled message"),
})
```
> MQTT packets still need to be correctly formed, so refer our [the test packets catalogue](packets/tpackets.go) and [MQTTv5 Specification](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) for inspiration.
See the [hooks example](examples/hooks/main.go) to see this feature in action.
### Testing
#### Unit Tests
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
```
go run --cover ./...
```
#### Paho Interoperability Test
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the mqtt v5 and v3 tests with `python3 client_test5.py` from the _interoperability_ folder.
> Note that there are currently a number of outstanding issues regarding false negatives in the paho suite, and as such, certain compatibility modes are enabled in the `paho/main.go` example.
## Performance Benchmarks
Mochi MQTT performance is comparable with popular brokers such as Mosquitto, EMQX, and others.
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a Apple Macbook Air M2, using `cmd/main.go` default settings. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better.
> The values presented in the benchmark are not representative of true messages per second throughput. They rely on an unusual calculation by mqtt-stresser, but are usable as they are consistent across all brokers.
> Benchmarks are provided as a general performance expectation guideline only.
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.0 | 127,216 | 125,748 | 124,279 | 319,250 | 309,327 | 299,405 |
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.0 | 45,615 | 30,129 | 21,138 | 232,717 | 86,323 | 50,402 |
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
Million Message Challenge (hit the server with 1 million messages immediately):
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
| -- | -- | -- | -- | -- | -- | -- |
| Mochi v2.2.0 | 51,044 | 4,682 | 2,345 | 72,634 | 7,645 | 2,464 |
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
> Not sure what's going on with EMQX here, perhaps the docker out-of-the-box settings are not optimal, so take it with a pinch of salt as we know for a fact it's a solid piece of software.
## Stargazers over time 🥰
[![Stargazers over time](https://starchart.cc/mochi-co/mqtt.svg)](https://starchart.cc/mochi-co/mqtt)
Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-co/mqtt/issues)
## Contributions
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request.

View File

@ -0,0 +1,568 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
"github.com/mochi-co/mqtt/v2/packets"
)
const (
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
)
// ReadFn is the function signature for the function used for reading and processing new packets.
type ReadFn func(*Client, packets.Packet) error
// Clients contains a map of the clients known by the broker.
type Clients struct {
internal map[string]*Client // clients known by the broker, keyed on client id.
sync.RWMutex
}
// NewClients returns an instance of Clients.
func NewClients() *Clients {
return &Clients{
internal: make(map[string]*Client),
}
}
// Add adds a new client to the clients map, keyed on client id.
func (cl *Clients) Add(val *Client) {
cl.Lock()
defer cl.Unlock()
cl.internal[val.ID] = val
}
// GetAll returns all the clients.
func (cl *Clients) GetAll() map[string]*Client {
cl.RLock()
defer cl.RUnlock()
m := map[string]*Client{}
for k, v := range cl.internal {
m[k] = v
}
return m
}
// Get returns the value of a client if it exists.
func (cl *Clients) Get(id string) (*Client, bool) {
cl.RLock()
defer cl.RUnlock()
val, ok := cl.internal[id]
return val, ok
}
// Len returns the length of the clients map.
func (cl *Clients) Len() int {
cl.RLock()
defer cl.RUnlock()
val := len(cl.internal)
return val
}
// Delete removes a client from the internal map.
func (cl *Clients) Delete(id string) {
cl.Lock()
defer cl.Unlock()
delete(cl.internal, id)
}
// GetByListener returns clients matching a listener id.
func (cl *Clients) GetByListener(id string) []*Client {
cl.RLock()
defer cl.RUnlock()
clients := make([]*Client, 0, cl.Len())
for _, client := range cl.internal {
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
clients = append(clients, client)
}
}
return clients
}
// Client contains information about a client known by the broker.
type Client struct {
Properties ClientProperties // client properties
State ClientState // the operational state of the client.
Net ClientConnection // network connection state of the clinet
ID string // the client id.
ops *ops // ops provides a reference to server ops.
sync.RWMutex // mutex
}
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
Remote string // the remote address of the client
Listener string // listener id of the client
Inline bool // client is an inline programmetic client
}
// ClientProperties contains the properties which define the client behaviour.
type ClientProperties struct {
Props packets.Properties
Will Will
Username []byte
ProtocolVersion byte
Clean bool
}
// Will contains the last will and testament details for a client connection.
type Will struct {
Payload []byte // -
User []packets.UserProperty // -
TopicName string // -
Flag uint32 // 0,1
WillDelayInterval uint32 // -
Qos byte // -
Retain bool // -
}
// State tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
done uint32 // atomic counter which indicates that the client has closed
outboundQty int32 // number of messages currently in the outbound queue
keepalive uint16 // the number of seconds the connection can wait
}
// newClient returns a new instance of Client. This is almost exclusively used by Server
// for creating new clients, but it lives here because it's not dependent.
func newClient(c net.Conn, o *ops) *Client {
cl := &Client{
State: ClientState{
Inflight: NewInflights(),
Subscriptions: NewSubscriptions(),
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
keepalive: defaultKeepalive,
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
},
Properties: ClientProperties{
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
},
ops: o,
}
if c != nil {
cl.Net = ClientConnection{
Conn: c,
bconn: bufio.NewReadWriter(
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize),
),
Remote: c.RemoteAddr().String(),
}
}
cl.refreshDeadline(cl.State.keepalive)
return cl
}
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
func (cl *Client) WriteLoop() {
for pk := range cl.State.outbound {
if err := cl.WritePacket(*pk); err != nil {
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
}
atomic.AddInt32(&cl.State.outboundQty, -1)
}
}
// ParseConnect parses the connect parameters and properties for a client.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
cl.Net.Listener = lid
cl.Properties.ProtocolVersion = pk.ProtocolVersion
cl.Properties.Username = pk.Connect.Username
cl.Properties.Clean = pk.Connect.Clean
cl.Properties.Props = pk.Properties.Copy(false)
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
cl.ID = pk.Connect.ClientIdentifier
if cl.ID == "" {
cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
cl.Properties.Props.AssignedClientID = cl.ID
}
cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
if pk.Connect.Keepalive > 0 {
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
}
if pk.Connect.WillFlag {
cl.Properties.Will = Will{
Qos: pk.Connect.WillQos,
Retain: pk.Connect.WillRetain,
Payload: pk.Connect.WillPayload,
TopicName: pk.Connect.WillTopic,
WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
User: pk.Connect.WillProperties.User,
}
if pk.Properties.SessionExpiryIntervalFlag &&
pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
}
if pk.Connect.WillFlag {
cl.Properties.Will.Flag = 1 // atomic for checking
}
}
cl.refreshDeadline(cl.State.keepalive)
}
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
if keepalive > 0 {
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
}
if cl.Net.Conn != nil {
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
}
}
// NextPacketID returns the next available (unused) packet id for the client.
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i
overflowed := false
for {
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
if i >= cl.ops.options.Capabilities.maximumPacketID {
overflowed = true
i = 0
continue
}
i++
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
}
}
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
func (cl *Client) ResendInflightMessages(force bool) error {
if cl.State.Inflight.Len() == 0 {
return nil
}
for _, tk := range cl.State.Inflight.GetAll(false) {
if tk.FixedHeader.Type == packets.Publish {
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
}
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
err := cl.WritePacket(tk)
if err != nil {
return err
}
if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosComplete(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
}
}
}
return nil
}
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
deleted := []uint16{}
for _, tk := range cl.State.Inflight.GetAll(false) {
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
cl.ops.hooks.OnQosDropped(cl, tk)
atomic.AddInt64(&cl.ops.info.Inflight, -1)
deleted = append(deleted, tk.PacketID)
}
}
}
return deleted
}
// Read reads incoming packets from the connected client and transforms them into
// packets to be handled by the packetHandler.
func (cl *Client) Read(packetHandler ReadFn) error {
var err error
for {
if atomic.LoadUint32(&cl.State.done) == 1 {
return nil
}
cl.refreshDeadline(cl.State.keepalive)
fh := new(packets.FixedHeader)
err = cl.ReadFixedHeader(fh)
if err != nil {
return err
}
pk, err := cl.ReadPacket(fh)
if err != nil {
return err
}
err = packetHandler(cl, pk) // Process inbound packet.
if err != nil {
return err
}
}
}
// Stop instructs the client to shut down all processing goroutines and disconnect.
func (cl *Client) Stop(err error) {
if atomic.LoadUint32(&cl.State.done) == 1 {
return
}
cl.State.endOnce.Do(func() {
if cl.Net.Conn != nil {
_ = cl.Net.Conn.Close() // omit close error
}
if err != nil {
cl.State.stopCause.Store(err)
}
atomic.StoreUint32(&cl.State.done, 1)
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
})
}
// StopCause returns the reason the client connection was stopped, if any.
func (cl *Client) StopCause() error {
if cl.State.stopCause.Load() == nil {
return nil
}
return cl.State.stopCause.Load().(error)
}
// Closed returns true if client connection is closed.
func (cl *Client) Closed() bool {
return atomic.LoadUint32(&cl.State.done) == 1
}
// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
return ErrConnectionClosed
}
b, err := cl.Net.bconn.ReadByte()
if err != nil {
return err
}
err = fh.Decode(b)
if err != nil {
return err
}
var bu int
fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn)
if err != nil {
return err
}
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
return nil
}
// ReadPacket reads the remaining buffer into an MQTT packet.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
pk.FixedHeader = *fh
p := make([]byte, pk.FixedHeader.Remaining)
n, err := io.ReadFull(cl.Net.bconn, p)
if err != nil {
return pk, err
}
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))
// Decode the remaining packet values using a fresh copy of the bytes,
// otherwise the next packet will change the data of this one.
px := append([]byte{}, p[:]...)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectDecode(px)
case packets.Disconnect:
err = pk.DisconnectDecode(px)
case packets.Connack:
err = pk.ConnackDecode(px)
case packets.Publish:
err = pk.PublishDecode(px)
if err == nil {
atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
}
case packets.Puback:
err = pk.PubackDecode(px)
case packets.Pubrec:
err = pk.PubrecDecode(px)
case packets.Pubrel:
err = pk.PubrelDecode(px)
case packets.Pubcomp:
err = pk.PubcompDecode(px)
case packets.Subscribe:
err = pk.SubscribeDecode(px)
case packets.Suback:
err = pk.SubackDecode(px)
case packets.Unsubscribe:
err = pk.UnsubscribeDecode(px)
case packets.Unsuback:
err = pk.UnsubackDecode(px)
case packets.Pingreq:
case packets.Pingresp:
case packets.Auth:
err = pk.AuthDecode(px)
default:
err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
}
if err != nil {
return pk, err
}
pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
return
}
// WritePacket encodes and writes a packet to the client.
func (cl *Client) WritePacket(pk packets.Packet) error {
if atomic.LoadUint32(&cl.State.done) == 1 {
return ErrConnectionClosed
}
if cl.Net.Conn == nil {
return nil
}
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
pk.ProtocolVersion = cl.Properties.ProtocolVersion
if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
}
if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
var err error
buf := new(bytes.Buffer)
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectEncode(buf)
case packets.Connack:
err = pk.ConnackEncode(buf)
case packets.Publish:
err = pk.PublishEncode(buf)
case packets.Puback:
err = pk.PubackEncode(buf)
case packets.Pubrec:
err = pk.PubrecEncode(buf)
case packets.Pubrel:
err = pk.PubrelEncode(buf)
case packets.Pubcomp:
err = pk.PubcompEncode(buf)
case packets.Subscribe:
err = pk.SubscribeEncode(buf)
case packets.Suback:
err = pk.SubackEncode(buf)
case packets.Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case packets.Unsuback:
err = pk.UnsubackEncode(buf)
case packets.Pingreq:
err = pk.PingreqEncode(buf)
case packets.Pingresp:
err = pk.PingrespEncode(buf)
case packets.Disconnect:
err = pk.DisconnectEncode(buf)
case packets.Auth:
err = pk.AuthEncode(buf)
default:
err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
}
if err != nil {
return err
}
if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
}
nb := net.Buffers{buf.Bytes()}
n, err := nb.WriteTo(cl.Net.Conn)
if err != nil {
return err
}
atomic.AddInt64(&cl.ops.info.BytesSent, n)
atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
if pk.FixedHeader.Type == packets.Publish {
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
}
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
return err
}

View File

@ -0,0 +1,745 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"errors"
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var errClientStop = errors.New("test stop")
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
r, w = net.Pipe()
cl = newClient(w, &ops{
info: new(system.Info),
hooks: new(Hooks),
log: &logger,
options: &Options{
Capabilities: &Capabilities{
ReceiveMaximum: 10,
TopicAliasMaximum: 10000,
MaximumClientWritesPending: 3,
maximumPacketID: 10,
},
},
})
cl.ID = "mochi"
cl.State.Inflight.maximumSendQuota = 5
cl.State.Inflight.sendQuota = 5
cl.State.Inflight.maximumReceiveQuota = 10
cl.State.Inflight.receiveQuota = 10
cl.Properties.Props.TopicAliasMaximum = 0
cl.Properties.Props.RequestResponseInfo = 0x1
go cl.WriteLoop()
return
}
func TestNewInflights(t *testing.T) {
require.NotNil(t, NewInflights().internal)
}
func TestNewClients(t *testing.T) {
cl := NewClients()
require.NotNil(t, cl.internal)
}
func TestClientsAdd(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
require.Contains(t, cl.internal, "t1")
}
func TestClientsGet(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
client, ok := cl.Get("t1")
require.Equal(t, true, ok)
require.Equal(t, "t1", client.ID)
}
func TestClientsGetAll(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
cl.Add(&Client{ID: "t3"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Contains(t, cl.internal, "t3")
clients := cl.GetAll()
require.Len(t, clients, 3)
}
func TestClientsLen(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
cl.Add(&Client{ID: "t2"})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
require.Equal(t, 2, cl.Len())
}
func TestClientsDelete(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1"})
require.Contains(t, cl.internal, "t1")
cl.Delete("t1")
_, ok := cl.Get("t1")
require.Equal(t, false, ok)
require.Nil(t, cl.internal["t1"])
}
func TestClientsGetByListener(t *testing.T) {
cl := NewClients()
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
require.Contains(t, cl.internal, "t1")
require.Contains(t, cl.internal, "t2")
clients := cl.GetByListener("tcp1")
require.NotEmpty(t, clients)
require.Equal(t, 1, len(clients))
require.Equal(t, "tcp1", clients[0].Net.Listener)
}
func TestNewClient(t *testing.T) {
cl, _, _ := newTestClient()
require.NotNil(t, cl)
require.NotNil(t, cl.State.Inflight.internal)
require.NotNil(t, cl.State.Subscriptions)
require.NotNil(t, cl.State.TopicAliases)
require.Equal(t, defaultKeepalive, cl.State.keepalive)
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
require.NotNil(t, cl.Net.Conn)
require.NotNil(t, cl.Net.bconn)
require.NotNil(t, cl.ops)
require.NotNil(t, cl.ops.options.Capabilities)
require.False(t, cl.Net.Inline)
}
func TestClientParseConnect(t *testing.T) {
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillTopic: "lwt",
WillPayload: []byte("lol gg"),
WillQos: 1,
WillRetain: false,
},
Properties: packets.Properties{
ReceiveMaximum: uint16(5),
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
}
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
cl, _, _ := newTestClient()
pk := packets.Packet{
ProtocolVersion: 4,
Connect: packets.ConnectParams{
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
Clean: true,
Keepalive: 60,
ClientIdentifier: "mochi",
WillFlag: true,
WillProperties: packets.Properties{
WillDelayInterval: 200,
},
},
Properties: packets.Properties{
SessionExpiryInterval: 100,
SessionExpiryIntervalFlag: true,
},
}
cl.ParseConnect("tcp1", pk)
require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval)
}
func TestClientParseConnectNoID(t *testing.T) {
cl, _, _ := newTestClient()
cl.ParseConnect("tcp1", packets.Packet{})
require.NotEmpty(t, cl.ID)
}
func TestClientNextPacketID(t *testing.T) {
cl, _, _ := newTestClient()
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(2), i)
}
func TestClientNextPacketIDInUse(t *testing.T) {
cl, _, _ := newTestClient()
// skip over 2
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(3), i)
// Skip over overflow
cl.State.Inflight.Set(packets.Packet{PacketID: 65535})
atomic.StoreUint32(&cl.State.packetID, 65534)
i, err = cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, uint32(1), i)
}
func TestClientNextPacketIDExhausted(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
}
i, err := cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
require.Equal(t, uint32(0), i)
}
func TestClientNextPacketIDOverflow(t *testing.T) {
cl, _, _ := newTestClient()
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
}
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
i, err := cl.NextPacketID()
require.NoError(t, err)
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
_, err = cl.NextPacketID()
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
}
func TestClientClearInflights(t *testing.T) {
cl, _, _ := newTestClient()
n := time.Now().Unix()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
require.Equal(t, 5, cl.State.Inflight.Len())
deleted := cl.ClearInflights(n, 4)
require.Len(t, deleted, 3)
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
require.Equal(t, 2, cl.State.Inflight.Len())
}
func TestClientResendInflightMessages(t *testing.T) {
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
cl, r, w := newTestClient()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
go func() {
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
time.Sleep(time.Millisecond)
w.Close()
}()
buf, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, 0, cl.State.Inflight.Len())
require.Equal(t, pk1.RawBytes, buf)
}
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
cl, r, _ := newTestClient()
r.Close()
cl.State.Inflight.Set(*pk1.Packet)
require.Equal(t, 1, cl.State.Inflight.Len())
err := cl.ResendInflightMessages(true)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrClosedPipe)
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
cl, _, _ := newTestClient()
err := cl.ResendInflightMessages(true)
require.NoError(t, err)
}
func TestClientRefreshDeadline(t *testing.T) {
cl, _, _ := newTestClient()
cl.refreshDeadline(10)
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
}
func TestClientReadFixedHeader(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0x00})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived))
}
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
}
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
cl, r, _ := newTestClient()
cl.ops.options.Capabilities.MaximumPacketSize = 2
defer cl.Stop(errClientStop)
go func() {
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
}
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.Equal(t, io.EOF, err)
}
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
}
func TestClientReadOK(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 18, // Fixed header
0, 5, // Topic Name - LSB+MSB
'a', '/', 'b', '/', 'c', // Topic Name
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload,
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
}()
var pks []packets.Packet
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
pks = append(pks, pk)
return nil
})
}()
err := <-o
require.Error(t, err)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 2, len(pks))
require.Equal(t, []packets.Packet{
{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 18,
},
TopicName: "a/b/c",
Payload: []byte("hello mochi"),
},
{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 11,
},
TopicName: "d/e/f",
Payload: []byte("yeah"),
},
}, pks)
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived))
}
func TestClientReadDone(t *testing.T) {
cl, _, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.State.done = 1
o := make(chan error)
go func() {
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
return nil
})
}()
require.NoError(t, <-o)
}
func TestClientStop(t *testing.T) {
cl, _, _ := newTestClient()
cl.Stop(nil)
require.Equal(t, nil, cl.State.stopCause.Load())
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
require.Equal(t, uint32(1), cl.State.done)
require.Equal(t, nil, cl.StopCause())
}
func TestClientClosed(t *testing.T) {
cl, _, _ := newTestClient()
require.False(t, cl.Closed())
cl.Stop(nil)
require.True(t, cl.Closed())
}
func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
})
r.Close()
}()
cl.Net.bconn = nil
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.Error(t, err)
require.ErrorIs(t, ErrConnectionClosed, err)
}
func TestClientReadReadHandlerErr(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5, // Topic Name - LSB+MSB
'd', '/', 'e', '/', 'f', // Topic Name
'y', 'e', 'a', 'h', // Payload
})
r.Close()
}()
err := cl.Read(func(cl *Client, pk packets.Packet) error {
return errors.New("test")
})
require.Error(t, err)
}
func TestClientReadReadPacketOK(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
packets.Publish << 4, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
pk, err := cl.ReadPacket(fh)
require.NoError(t, err)
require.NotNil(t, pk)
require.Equal(t, packets.Packet{
ProtocolVersion: cl.Properties.ProtocolVersion,
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
Remaining: 11,
},
TopicName: "d/e/f",
Payload: []byte("yeah"),
}, pk)
}
func TestClientReadPacket(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
for _, tx := range pkTable {
tt := tx // avoid data race
t.Run(tt.Desc, func(t *testing.T) {
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
go func() {
r.Write(tt.RawBytes)
}()
fh := new(packets.FixedHeader)
err := cl.ReadFixedHeader(fh)
require.NoError(t, err)
if tt.Packet.ProtocolVersion == 5 {
cl.Properties.ProtocolVersion = 5
} else {
cl.Properties.ProtocolVersion = 0
}
pk, err := cl.ReadPacket(fh)
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc)
require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc)
if tt.Packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc)
}
})
}
}
func TestClientReadPacketInvalidTypeError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
_, err := cl.ReadPacket(&packets.FixedHeader{})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid packet type")
}
func TestClientWritePacket(t *testing.T) {
for _, tt := range pkTable {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
o := make(chan []byte)
go func() {
buf, err := io.ReadAll(r)
require.NoError(t, err)
o <- buf
}()
err := cl.WritePacket(*tt.Packet)
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
time.Sleep(2 * time.Millisecond)
cl.Net.Conn.Close()
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
cl.Stop(errClientStop)
time.Sleep(time.Millisecond * 1)
// The stop cause is either the test error, EOF, or a
// closed pipe, depending on which goroutine runs first.
err = cl.StopCause()
require.True(t,
errors.Is(err, errClientStop) ||
errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe))
require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent))
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent))
if tt.Packet.FixedHeader.Type == packets.Publish {
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent))
}
}
}
func TestWriteClientOversizePacket(t *testing.T) {
cl, _, _ := newTestClient()
cl.Properties.Props.MaximumPacketSize = 2
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
err := cl.WritePacket(pk)
require.Error(t, err)
require.ErrorIs(t, packets.ErrPacketTooLarge, err)
}
func TestClientReadPacketReadingError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
Type: 0,
Remaining: 11,
})
require.Error(t, err)
}
func TestClientReadPacketReadUnknown(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
go func() {
r.Write([]byte{
0, 11, // Fixed header
0, 5,
'd', '/', 'e', '/', 'f',
'y', 'e', 'a', 'h',
})
r.Close()
}()
_, err := cl.ReadPacket(&packets.FixedHeader{
Remaining: 1,
})
require.Error(t, err)
}
func TestClientWritePacketWriteNoConn(t *testing.T) {
cl, _, _ := newTestClient()
cl.Stop(errClientStop)
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
require.Equal(t, ErrConnectionClosed, err)
}
func TestClientWritePacketWriteError(t *testing.T) {
cl, _, _ := newTestClient()
cl.Net.Conn.Close()
err := cl.WritePacket(*pkTable[1].Packet)
require.Error(t, err)
}
func TestClientWritePacketInvalidPacket(t *testing.T) {
cl, _, _ := newTestClient()
err := cl.WritePacket(packets.Packet{})
require.Error(t, err)
}
var (
pkTable = []packets.TPacketCase{
packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311),
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5),
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession),
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic),
packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5),
packets.TPacketData[packets.Puback].Get(packets.TPuback),
packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe),
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5),
packets.TPacketData[packets.Suback].Get(packets.TSuback),
packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5),
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe),
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5),
packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback),
packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5),
packets.TPacketData[packets.Pingreq].Get(packets.TPingreq),
packets.TPacketData[packets.Pingresp].Get(packets.TPingresp),
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect),
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5),
packets.TPacketData[packets.Auth].Get(packets.TAuth),
}
)

View File

@ -0,0 +1,52 @@
{
"auth": [
{
"username": "leandro",
"password": "leandro",
"allow": true
},
{
"username": "steve",
"password": "steve",
"allow": true
},
{
"username": "root",
"password": "root",
"allow": true
},
{
"remote": "*",
"allow": false
},
{
"remote": "*",
"allow": false
}
],
"acl": [
{
"remote": "*"
},
{
"username": "leandro",
"filters": {
"oktopus/+/agent/+": 1,
"oktopus/+/controller/+": 2
}
},
{
"username": "steve",
"filters": {
"oktopus/+/agent/+": 1,
"oktopus/+/controller/+": 2
}
},
{
"username": "root",
"filters": {
"#": 3
}
}
]
}

View File

@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"github.com/rs/zerolog"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
wsAddr := flag.String("ws", "", "network address for Websocket listener")
infoAddr := flag.String("info", "", "network address for web info dashboard listener")
path := flag.String("path", "", "path to data auth file")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(&mqtt.Options{
//Capabilities: &mqtt.Capabilities{
// ServerKeepAlive: 10000,
// ReceiveMaximum: math.MaxUint16,
// MaximumMessageExpiryInterval: math.MaxUint32,
// MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions
// MaximumClientWritesPending: 65536,
// MaximumPacketSize: 0,
// MaximumQos: 2,
//},
})
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
if *path != "" {
data, err := os.ReadFile(*path)
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(auth.Hook), &auth.Options{
Data: data,
})
if err != nil {
log.Fatal(err)
}
} else {
err := server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
}
if *tcpAddr != "" {
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
}
if *wsAddr != "" {
ws := listeners.NewWebsocket("ws1", *wsAddr, nil)
err := server.AddListener(ws)
if err != nil {
log.Fatal(err)
}
}
if *infoAddr != "" {
stats := listeners.NewHTTPStats("stats", *infoAddr, nil, server.Info)
err := server.AddListener(stats)
if err != nil {
log.Fatal(err)
}
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
authRules := &auth.Ledger{
Auth: auth.AuthRules{ // Auth disallows all by default
{Username: "peach", Password: "password1", Allow: true},
{Username: "melon", Password: "password2", Allow: true},
{Remote: "127.0.0.1:*", Allow: true},
{Remote: "localhost:*", Allow: true},
},
ACL: auth.ACLRules{ // ACL allows all by default
{Remote: "127.0.0.1:*"}, // local superuser allow all
{
// user melon can read and write to their own topic
Username: "melon", Filters: auth.Filters{
"melon/#": auth.ReadWrite,
"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
},
},
{
// Otherwise, no clients have publishing permissions
Filters: auth.Filters{
"#": auth.ReadOnly,
"updates/#": auth.Deny,
},
},
},
}
// you may also find this useful...
// d, _ := authRules.ToYAML()
// d, _ := authRules.ToJSON()
// fmt.Println(string(d))
server := mqtt.New(nil)
err := server.AddHook(new(auth.Hook), &auth.Options{
Ledger: authRules,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,52 @@
{
"auth": [
{
"username": "leandro",
"password": "leandro",
"allow": true
},
{
"username": "steve",
"password": "steve",
"allow": true
},
{
"username": "root",
"password": "root",
"allow": true
},
{
"remote": "*",
"allow": false
},
{
"remote": "*",
"allow": false
}
],
"acl": [
{
"remote": "*"
},
{
"username": "leandro",
"filters": {
"oktopus/+/agent/+": 1,
"oktopus/+/controller/+": 2
}
},
{
"username": "steve",
"filters": {
"oktopus/+/agent/+": 1,
"oktopus/+/controller/+": 2
}
},
{
"username": "root",
"filters": {
"#": 3
}
}
]
}

View File

@ -0,0 +1,21 @@
auth:
- username: peach
password: password1
allow: true
- username: melon
password: password2
allow: true
# - remote: 127.0.0.1:*
# allow: true
# - remote: localhost:*
# allow: true
acl:
# 0 = deny, 1 = read only, 2 = write only, 3 = read and write
- remote: 127.0.0.1:*
- username: melon
filters:
melon/#: 3
updates/#: 2
- filters:
'#': 1
updates/#: 0

View File

@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
// You can also run from top-level server.go folder:
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.yaml
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.json
path := flag.String("path", "auth.yaml", "path to data auth file")
flag.Parse()
// Get ledger from yaml file
data, err := os.ReadFile(*path)
if err != nil {
log.Fatal(err)
}
server := mqtt.New(nil)
err = server.AddHook(new(auth.Hook), &auth.Options{
Data: data, // build ledger from byte slice, yaml or json
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
flag.Parse()
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,62 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/debug"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/rs/zerolog"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(debug.Hook), &debug.Options{
// ShowPacketData: true,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,143 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"bytes"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/mochi-co/mqtt/v2/packets"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
err = server.AddHook(new(ExampleHook), map[string]any{})
if err != nil {
log.Fatal(err)
}
// Start the server
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
// Demonstration of directly publishing messages to a topic via the
// `server.Publish` method. Subscribe to `direct/publish` using your
// MQTT client to see the messages.
go func() {
cl := server.NewClient(nil, "local", "inline", true)
for range time.Tick(time.Second * 1) {
err := server.InjectPacket(cl, packets.Packet{
FixedHeader: packets.FixedHeader{
Type: packets.Publish,
},
TopicName: "direct/publish",
Payload: []byte("injected scheduled message"),
})
if err != nil {
server.Log.Error().Err(err).Msg("server.InjectPacket")
}
server.Log.Info().Msgf("main.go injected packet to direct/publish")
}
}()
// There is also a shorthand convenience function, Publish, for easily sending
// publish packets if you are not concerned with creating your own packets.
go func() {
for range time.Tick(time.Second * 5) {
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
if err != nil {
server.Log.Error().Err(err).Msg("server.Publish")
}
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}
type ExampleHook struct {
mqtt.HookBase
}
func (h *ExampleHook) ID() string {
return "events-example"
}
func (h *ExampleHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnect,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnPublished,
mqtt.OnPublish,
}, []byte{b})
}
func (h *ExampleHook) Init(config any) error {
h.Log.Info().Msg("initialised")
return nil
}
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
}
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected")
}
func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes)
}
func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed")
}
func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client")
pkx := pk
if string(pk.Payload) == "hello" {
pkx.Payload = []byte("hello world")
h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client")
}
return pkx, nil
}
func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client")
}

View File

@ -0,0 +1,74 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"bytes"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/mochi-co/mqtt/v2/packets"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
server.Options.Capabilities.ServerKeepAlive = 60
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
_ = server.AddHook(new(pahoAuthHook), nil)
tcp := listeners.NewTCP("t1", ":1883", nil)
err := server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}
type pahoAuthHook struct {
mqtt.HookBase
}
func (h *pahoAuthHook) ID() string {
return "allow-all-auth"
}
func (h *pahoAuthHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
return true
}
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return topic != "test/nosubscribe"
}

View File

@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/badger"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
badgerPath := ".badger"
defer os.RemoveAll(badgerPath) // remove the example badger files at the end
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(badger.Hook), &badger.Options{
Path: badgerPath,
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/bolt"
"github.com/mochi-co/mqtt/v2/listeners"
"go.etcd.io/bbolt"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
err := server.AddHook(new(bolt.Hook), &bolt.Options{
Path: "bolt.db",
Options: &bbolt.Options{
Timeout: 500 * time.Millisecond,
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/hooks/storage/redis"
"github.com/mochi-co/mqtt/v2/listeners"
"github.com/rs/zerolog"
rv8 "github.com/go-redis/redis/v8"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
l := server.Log.Level(zerolog.DebugLevel)
server.Log = &l
err := server.AddHook(new(redis.Hook), &redis.Options{
Options: &rv8.Options{
Addr: "localhost:6379", // default redis address
Password: "", // your password
DB: 0, // your redis db
},
})
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
// An example of configuring various server options...
options := &mqtt.Options{
// InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
}
server := mqtt.New(options)
// For security reasons, the default implementation disallows all connections.
// If you want to allow all connections, you must specifically allow it.
err := server.AddHook(new(auth.AllowHook), nil)
if err != nil {
log.Fatal(err)
}
tcp := listeners.NewTCP("t1", ":1883", nil)
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,117 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"crypto/tls"
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
var (
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
// Optionally, if you want clients to authenticate only with certs issued by your CA,
// you might want to use something like this:
// certPool := x509.NewCertPool()
// _ = certPool.AppendCertsFromPEM(caCertPem)
// tlsConfig := &tls.Config{
// ClientCAs: certPool,
// ClientAuth: tls.RequireAndVerifyClientCert,
// }
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
tcp := listeners.NewTCP("t1", ":1883", &listeners.Config{
TLSConfig: tlsConfig,
})
err = server.AddListener(tcp)
if err != nil {
log.Fatal(err)
}
ws := listeners.NewWebsocket("ws1", ":1882", &listeners.Config{
TLSConfig: tlsConfig,
})
err = server.AddListener(ws)
if err != nil {
log.Fatal(err)
}
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{
TLSConfig: tlsConfig,
}, nil)
err = server.AddListener(stats)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package main
import (
"log"
"os"
"os/signal"
"syscall"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/auth"
"github.com/mochi-co/mqtt/v2/listeners"
)
func main() {
sigs := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigs
done <- true
}()
server := mqtt.New(nil)
_ = server.AddHook(new(auth.AllowHook), nil)
ws := listeners.NewWebsocket("ws1", ":1882", nil)
err := server.AddListener(ws)
if err != nil {
log.Fatal(err)
}
go func() {
err := server.Serve()
if err != nil {
log.Fatal(err)
}
}()
<-done
server.Log.Warn().Msg("caught signal, stopping...")
server.Close()
server.Log.Info().Msg("main.go finished")
}

View File

@ -0,0 +1,40 @@
module github.com/mochi-co/mqtt/v2
go 1.19
require (
github.com/alicebob/miniredis/v2 v2.23.0
github.com/asdine/storm v2.1.2+incompatible
github.com/asdine/storm/v3 v3.2.1
github.com/go-redis/redis/v8 v8.11.5
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.3.5
github.com/rs/xid v1.4.0
github.com/rs/zerolog v1.28.0
github.com/stretchr/testify v1.7.1
github.com/timshannon/badgerhold v1.0.0
go.etcd.io/bbolt v1.3.5
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/badger v1.6.0 // indirect
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/golang/protobuf v1.5.0 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sys v0.5.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.28.1 // indirect
)

View File

@ -0,0 +1,143 @@
github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M=
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM=
github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM=
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw=
github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/asdine/storm v2.1.2+incompatible h1:dczuIkyqwY2LrtXPz8ixMrU/OFgZp71kbKTHGrXYt/Q=
github.com/asdine/storm v2.1.2+incompatible/go.mod h1:RarYDc9hq1UPLImuiXK3BIWPJLdIygvV3PsInK0FbVQ=
github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac=
github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/badger v1.6.0 h1:DshxFxZWXUcO0xX476VJC07Xsr6ZCBVRHKZ93Oh7Evo=
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/timshannon/badgerhold v1.0.0 h1:LtqnDRVP7294FWRiZCIfQa6Tt0bGmlzbO8c364QC2Y8=
github.com/timshannon/badgerhold v1.0.0/go.mod h1:Vv2Jj0PAfzqViEpGvJzLP8PY07x1iXLgKRuLY7bqPOE=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -0,0 +1,794 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co, thedevop
package mqtt
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
)
const (
SetOptions byte = iota
OnSysInfoTick
OnStarted
OnStopped
OnConnectAuthenticate
OnACLCheck
OnConnect
OnSessionEstablished
OnDisconnect
OnAuthPacket
OnPacketRead
OnPacketEncode
OnPacketSent
OnPacketProcessed
OnSubscribe
OnSubscribed
OnSelectSubscribers
OnUnsubscribe
OnUnsubscribed
OnPublish
OnPublished
OnPublishDropped
OnRetainMessage
OnQosPublish
OnQosComplete
OnQosDropped
OnWill
OnWillSent
OnClientExpired
OnRetainedExpired
StoredClients
StoredSubscriptions
StoredInflightMessages
StoredRetainedMessages
StoredSysInfo
)
var (
// ErrInvalidConfigType indicates a different Type of config value was expected to what was received.
ErrInvalidConfigType = errors.New("invalid config type provided")
)
// Hook provides an interface of handlers for different events which occur
// during the lifecycle of the broker.
type Hook interface {
ID() string
Provides(b byte) bool
Init(config any) error
Stop() error
SetOpts(l *zerolog.Logger, o *HookOptions)
OnStarted()
OnStopped()
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
OnACLCheck(cl *Client, topic string, write bool) bool
OnSysInfoTick(*system.Info)
OnConnect(cl *Client, pk packets.Packet)
OnSessionEstablished(cl *Client, pk packets.Packet)
OnDisconnect(cl *Client, err error, expire bool)
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation
OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client
OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client
OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled)
OnSubscribe(cl *Client, pk packets.Packet) packets.Packet
OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers
OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet
OnUnsubscribed(cl *Client, pk packets.Packet)
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
OnPublished(cl *Client, pk packets.Packet)
OnPublishDropped(cl *Client, pk packets.Packet)
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
OnQosComplete(cl *Client, pk packets.Packet)
OnQosDropped(cl *Client, pk packets.Packet)
OnWill(cl *Client, will Will) (Will, error)
OnWillSent(cl *Client, pk packets.Packet)
OnClientExpired(cl *Client)
OnRetainedExpired(filter string)
StoredClients() ([]storage.Client, error)
StoredSubscriptions() ([]storage.Subscription, error)
StoredInflightMessages() ([]storage.Message, error)
StoredRetainedMessages() ([]storage.Message, error)
StoredSysInfo() (storage.SystemInfo, error)
}
// HookOptions contains values which are inherited from the server on initialisation.
type HookOptions struct {
Capabilities *Capabilities
}
// Hooks is a slice of Hook interfaces to be called in sequence.
type Hooks struct {
Log *zerolog.Logger // a logger for the hook (from the server)
internal atomic.Value // a slice of []Hook
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
qty int64 // the number of hooks in use
sync.Mutex // a mutex for locking when adding hooks
}
// Len returns the number of hooks added.
func (h *Hooks) Len() int64 {
return atomic.LoadInt64(&h.qty)
}
// Provides returns true if any one hook provides any of the requested hook methods.
func (h *Hooks) Provides(b ...byte) bool {
for _, hook := range h.GetAll() {
for _, hb := range b {
if hook.Provides(hb) {
return true
}
}
}
return false
}
// Add adds and initializes a new hook.
func (h *Hooks) Add(hook Hook, config any) error {
h.Lock()
defer h.Unlock()
err := hook.Init(config)
if err != nil {
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
}
i, ok := h.internal.Load().([]Hook)
if !ok {
i = []Hook{}
}
i = append(i, hook)
h.internal.Store(i)
atomic.AddInt64(&h.qty, 1)
h.wg.Add(1)
return nil
}
// GetAll returns a slice of all the hooks.
func (h *Hooks) GetAll() []Hook {
i, ok := h.internal.Load().([]Hook)
if !ok {
return []Hook{}
}
return i
}
// Stop indicates all attached hooks to gracefully end.
func (h *Hooks) Stop() {
go func() {
for _, hook := range h.GetAll() {
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
if err := hook.Stop(); err != nil {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
}
h.wg.Done()
}
}()
h.wg.Wait()
}
// OnSysInfoTick is called when the $SYS topic values are published out.
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSysInfoTick) {
hook.OnSysInfoTick(sys)
}
}
}
// OnStarted is called when the server has successfully started.
func (h *Hooks) OnStarted() {
for _, hook := range h.GetAll() {
if hook.Provides(OnStarted) {
hook.OnStarted()
}
}
}
// OnStopped is called when the server has successfully stopped.
func (h *Hooks) OnStopped() {
for _, hook := range h.GetAll() {
if hook.Provides(OnStopped) {
hook.OnStopped()
}
}
}
// OnConnect is called when a new client connects.
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnect) {
hook.OnConnect(cl, pk)
}
}
}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSessionEstablished) {
hook.OnSessionEstablished(cl, pk)
}
}
}
// OnDisconnect is called when a client is disconnected for any reason.
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
for _, hook := range h.GetAll() {
if hook.Provides(OnDisconnect) {
hook.OnDisconnect(cl, err, expire)
}
}
}
// OnPacketRead is called when a packet is received from a client.
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketRead) {
npk, err := hook.OnPacketRead(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected")
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
return
}
// OnAuthPacket is called when an auth packet is received. It is intended to allow developers
// to create their own auth packet handling mechanisms.
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnAuthPacket) {
npk, err := hook.OnAuthPacket(cl, pkx)
if err != nil {
return pk, err
}
pkx = npk
}
}
return
}
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketEncode) {
pk = hook.OnPacketEncode(cl, pk)
}
}
return pk
}
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketProcessed) {
hook.OnPacketProcessed(cl, pk, err)
}
}
}
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
// containing the bytes sent.
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPacketSent) {
hook.OnPacketSent(cl, pk, b)
}
}
}
// OnSubscribe is called when a client subscribes to one or more filters. This method
// differs from OnSubscribed in that it allows you to modify the subscription values
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribe) {
pk = hook.OnSubscribe(cl, pk)
}
}
return pk
}
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
for _, hook := range h.GetAll() {
if hook.Provides(OnSubscribed) {
hook.OnSubscribed(cl, pk, reasonCodes)
}
}
}
// OnSelectSubscribers is called when subscribers have been collected for a topic, but before
// shared subscription subscribers have been selected. This hook can be used to programmatically
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
// group in a custom manner (such as based on client id, ip, etc).
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
for _, hook := range h.GetAll() {
if hook.Provides(OnSelectSubscribers) {
subs = hook.OnSelectSubscribers(subs, pk)
}
}
return subs
}
// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method
// differs from OnUnsubscribed in that it allows you to modify the unsubscription values
// before the packet is processed. The return values of the hook methods are passed-through
// in the order the hooks were attached.
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribe) {
pk = hook.OnUnsubscribe(cl, pk)
}
}
return pk
}
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnUnsubscribed) {
hook.OnUnsubscribed(cl, pk)
}
}
}
// OnPublish is called when a client publishes a message. This method differs from OnPublished
// in that it allows you to modify you to modify the incoming packet before it is processed.
// The return values of the hook methods are passed-through in the order the hooks were attached.
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
pkx = pk
for _, hook := range h.GetAll() {
if hook.Provides(OnPublish) {
npk, err := hook.OnPublish(cl, pkx)
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
return pk, err
} else if err != nil {
continue
}
pkx = npk
}
}
return
}
// OnPublished is called when a client has published a message to subscribers.
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublished) {
hook.OnPublished(cl, pk)
}
}
}
// OnPublishDropped is called when a message to a client was dropped instead of delivered
// such as when a client is too slow to respond.
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnPublishDropped) {
hook.OnPublishDropped(cl, pk)
}
}
}
// OnRetainMessage is called then a published message is retained.
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainMessage) {
hook.OnRetainMessage(cl, pk, r)
}
}
}
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
// In other words, this method is called when a new inflight message is created or resent.
// It is typically used to store a new inflight message.
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosPublish) {
hook.OnQosPublish(cl, pk, sent, resends)
}
}
}
// OnQosComplete is called when the Qos flow for a message has been completed.
// In other words, when an inflight message is resolved.
// It is typically used to delete an inflight message from a store.
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosComplete) {
hook.OnQosComplete(cl, pk)
}
}
}
// OnQosDropped is called the Qos flow for a message expires. In other words, when
// an inflight message expires or is abandoned. It is typically used to delete an
// inflight message from a store.
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnQosDropped) {
hook.OnQosDropped(cl, pk)
}
}
}
// OnWill is called when a client disconnects and publishes an LWT message. This method
// differs from OnWillSent in that it allows you to modify the LWT message before it is
// published. The return values of the hook methods are passed-through in the order
// the hooks were attached.
func (h *Hooks) OnWill(cl *Client, will Will) Will {
for _, hook := range h.GetAll() {
if hook.Provides(OnWill) {
mlwt, err := hook.OnWill(cl, will)
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error")
continue
}
will = mlwt
}
}
return will
}
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
for _, hook := range h.GetAll() {
if hook.Provides(OnWillSent) {
hook.OnWillSent(cl, pk)
}
}
}
// OnClientExpired is called when a client session has expired and should be deleted.
func (h *Hooks) OnClientExpired(cl *Client) {
for _, hook := range h.GetAll() {
if hook.Provides(OnClientExpired) {
hook.OnClientExpired(cl)
}
}
}
// OnRetainedExpired is called when a retained message has expired and should be deleted.
func (h *Hooks) OnRetainedExpired(filter string) {
for _, hook := range h.GetAll() {
if hook.Provides(OnRetainedExpired) {
hook.OnRetainedExpired(filter)
}
}
}
// StoredClients returns all clients, e.g. from a persistent store, is used to
// populate the server clients list before start.
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClients) {
v, err := hook.StoredClients()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
// used to populate the server subscriptions list before start.
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSubscriptions) {
v, err := hook.StoredSubscriptions()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
// and is used to populate the restored clients with inflight messages before start.
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredInflightMessages) {
v, err := hook.StoredInflightMessages()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
// and is used to populate the server topics with retained messages before start.
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredRetainedMessages) {
v, err := hook.StoredRetainedMessages()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages")
return v, err
}
if len(v) > 0 {
return v, nil
}
}
}
return
}
// StoredSysInfo returns a set of system info values.
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredSysInfo) {
v, err := hook.StoredSysInfo()
if err != nil {
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info")
return v, err
}
if v.Version != "" {
return v, nil
}
}
}
return
}
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
// An implementation of this method MUST be used to allow or deny access to the
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check connecting users against an existing user database.
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
for _, hook := range h.GetAll() {
if hook.Provides(OnConnectAuthenticate) {
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
return true
}
}
}
return false
}
// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter.
// An implementation of this method MUST be used to allow or deny access to the
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
// check publishing and subscribing users against an existing permissions or roles database.
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
for _, hook := range h.GetAll() {
if hook.Provides(OnACLCheck) {
if ok := hook.OnACLCheck(cl, topic, write); ok {
return true
}
}
}
return false
}
// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
Hook
Log *zerolog.Logger
Opts *HookOptions
}
// ID returns the ID of the hook.
func (h *HookBase) ID() string {
return "base"
}
// Provides indicates which methods a hook provides. The default is none - this method
// should be overridden by the embedding hook.
func (h *HookBase) Provides(b byte) bool {
return false
}
// Init performs any pre-start initializations for the hook, such as connecting to databases
// or opening files.
func (h *HookBase) Init(config any) error {
return nil
}
// SetOpts is called by the server to propagate internal values and generally should
// not be called manually.
func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
h.Log = l
h.Opts = opts
}
// Stop is called to gracefully shutdown the hook.
func (h *HookBase) Stop() error {
return nil
}
// OnStarted is called when the server starts.
func (h *HookBase) OnStarted() {}
// OnStopped is called when the server stops.
func (h *HookBase) OnStopped() {}
// OnSysInfoTick is called when the server publishes system info.
func (h *HookBase) OnSysInfoTick(*system.Info) {}
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return false
}
// OnACLCheck is called when a user attempts to subscribe or publish to a topic.
func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}
// OnConnect is called when a new client connects.
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {}
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
// OnDisconnect is called when a client is disconnected for any reason.
func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {}
// OnAuthPacket is called when an auth packet is received from the client.
func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPacketRead is called when a packet is received.
func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPacketEncode is called before a packet is byte-encoded and written to the client.
func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnPacketSent is called immediately after a packet is written to a client.
func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {}
// OnPacketProcessed is called immediately after a packet from a client is processed.
func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {}
// OnSubscribe is called when a client subscribes to one or more filters.
func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnSubscribed is called when a client subscribes to one or more filters.
func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {}
// OnSelectSubscribers is called when selecting subscribers to receive a message.
func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
return subs
}
// OnUnsubscribe is called when a client unsubscribes from one or more filters.
func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
return pk
}
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {}
// OnPublish is called when a client publishes a message.
func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
return pk, nil
}
// OnPublished is called when a client has published a message to subscribers.
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
// OnRetainMessage is called then a published message is retained.
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
// OnQosDropped is called the Qos flow for a message expires.
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
// OnWill is called when a client disconnects and publishes an LWT message.
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
return will, nil
}
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {}
// OnClientExpired is called when a client session has expired.
func (h *HookBase) OnClientExpired(cl *Client) {}
// OnRetainedExpired is called when a retained message for a topic has expired.
func (h *HookBase) OnRetainedExpired(topic string) {}
// StoredClients returns all clients from a store.
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
return
}
// StoredSubscriptions returns all subcriptions from a store.
func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
return
}
// StoredInflightMessages returns all inflight messages from a store.
func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) {
return
}
// StoredRetainedMessages returns all retained messages from a store.
func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
return
}
// StoredSysInfo returns a set of system info values.
func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
return
}

View File

@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
)
// AllowHook is an authentication hook which allows connection access
// for all users and read and write access to all topics.
type AllowHook struct {
mqtt.HookBase
}
// ID returns the ID of the hook.
func (h *AllowHook) ID() string {
return "allow-all-auth"
}
// Provides indicates which hook methods this hook provides.
func (h *AllowHook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
// OnConnectAuthenticate returns true/allowed for all requests.
func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
return true
}
// OnACLCheck returns true/allowed for all checks.
func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
return true
}

View File

@ -0,0 +1,35 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
func TestAllowAllID(t *testing.T) {
h := new(AllowHook)
require.Equal(t, "allow-all-auth", h.ID())
}
func TestAllowAllProvides(t *testing.T) {
h := new(AllowHook)
require.True(t, h.Provides(mqtt.OnACLCheck))
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
require.False(t, h.Provides(mqtt.OnPublished))
}
func TestAllowAllOnConnectAuthenticate(t *testing.T) {
h := new(AllowHook)
require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{}))
}
func TestAllowAllOnACLCheck(t *testing.T) {
h := new(AllowHook)
require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true))
}

View File

@ -0,0 +1,107 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"bytes"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
)
// Options contains the configuration/rules data for the auth ledger.
type Options struct {
Data []byte
Ledger *Ledger
}
// Hook is an authentication hook which implements an auth ledger.
type Hook struct {
mqtt.HookBase
config *Options
ledger *Ledger
}
// ID returns the ID of the hook.
func (h *Hook) ID() string {
return "auth-ledger"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnConnectAuthenticate,
mqtt.OnACLCheck,
}, []byte{b})
}
// Init configures the hook with the auth ledger to be used for checking.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
var err error
if h.config.Ledger != nil {
h.ledger = h.config.Ledger
} else if len(h.config.Data) > 0 {
h.ledger = new(Ledger)
err = h.ledger.Unmarshal(h.config.Data)
}
if err != nil {
return err
}
if h.ledger == nil {
h.ledger = &Ledger{
Auth: AuthRules{},
ACL: ACLRules{},
}
}
h.Log.Info().
Int("authentication", len(h.ledger.Auth)).
Int("acl", len(h.ledger.ACL)).
Msg("loaded auth rules")
return nil
}
// OnConnectAuthenticate returns true if the connecting client has rules which provide access
// in the auth ledger.
func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
if _, ok := h.ledger.AuthOk(cl, pk); ok {
return true
}
h.Log.Info().
Str("username", string(pk.Connect.Username)).
Str("remote", cl.Net.Remote).
Msg("client failed authentication check")
return false
}
// OnACLCheck returns true if the connecting client has matching read or write access to subscribe
// or publish to a given topic.
func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
if _, ok := h.ledger.ACLOk(cl, topic, write); ok {
return true
}
h.Log.Debug().
Str("client", cl.ID).
Str("username", string(cl.Properties.Username)).
Str("topic", topic).
Msg("client failed allowed ACL check")
return false
}

View File

@ -0,0 +1,213 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"os"
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
// func teardown(t *testing.T, path string, h *Hook) {
// h.Stop()
// }
func TestBasicID(t *testing.T) {
h := new(Hook)
require.Equal(t, "auth-ledger", h.ID())
}
func TestBasicProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnACLCheck))
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
require.False(t, h.Provides(mqtt.OnPublish))
}
func TestBasicInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestBasicInitDefaultConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
}
func TestBasicInitWithLedgerPointer(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := &Ledger{
Auth: []AuthRule{
{
Remote: "127.0.0.1",
Allow: true,
},
},
ACL: []ACLRule{
{
Remote: "127.0.0.1",
Filters: Filters{
"#": ReadWrite,
},
},
},
}
err := h.Init(&Options{
Ledger: ln,
})
require.NoError(t, err)
require.Same(t, ln, h.ledger)
}
func TestBasicInitWithLedgerJSON(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: ledgerJSON,
})
require.NoError(t, err)
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
}
func TestBasicInitWithLedgerYAML(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: ledgerYAML,
})
require.NoError(t, err)
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
}
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Nil(t, h.ledger)
err := h.Init(&Options{
Data: []byte("fdsfdsafasd"),
})
require.Error(t, err)
}
func TestOnConnectAuthenticate(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
ln.ACL = checkLedger.ACL
err := h.Init(
&Options{
Ledger: ln,
},
)
require.NoError(t, err)
require.True(t, h.OnConnectAuthenticate(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
))
require.False(t, h.OnConnectAuthenticate(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
))
require.False(t, h.OnConnectAuthenticate(
&mqtt.Client{},
packets.Packet{},
))
}
func TestOnACL(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
ln := new(Ledger)
ln.Auth = checkLedger.Auth
ln.ACL = checkLedger.ACL
err := h.Init(
&Options{
Ledger: ln,
},
)
require.NoError(t, err)
require.True(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"mochi/info",
true,
))
require.False(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"d/j/f",
true,
))
require.True(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"readonly",
false,
))
require.False(t, h.OnACLCheck(
&mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
"readonly",
true,
))
}

View File

@ -0,0 +1,231 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"encoding/json"
"strings"
"sync"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"gopkg.in/yaml.v3"
)
const (
Deny Access = iota // user cannot access the topic
ReadOnly // user can only subscribe to the topic
WriteOnly // user can only publish to the topic
ReadWrite // user can both publish and subscribe to the topic
)
// Access determines the read/write privileges for an ACL rule.
type Access byte
// Users contains a map of access rules for specific users, keyed on username.
type Users map[string]UserRule
// UserRule defines a set of access rules for a specific user.
type UserRule struct {
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired
Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user
}
// AuthRules defines generic access rules applicable to all users.
type AuthRules []AuthRule
type AuthRule struct {
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users
}
// ACLRules defines generic topic or filter access rules applicable to all users.
type ACLRules []ACLRule
// ACLRule defines access rules for a specific topic or filter.
type ACLRule struct {
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match
}
// Filters is a map of Access rules keyed on filter.
type Filters map[RString]Access
// RString is a rule value string.
type RString string
// Matches returns true if the rule matches a given string.
func (r RString) Matches(a string) bool {
rr := string(r)
if r == "" || r == "*" || a == rr {
return true
}
i := strings.Index(rr, "*")
if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 {
return true
}
return false
}
// FilterMatches returns true if a filter matches a topic rule.
func (f RString) FilterMatches(a string) bool {
_, ok := MatchTopic(string(f), a)
return ok
}
// MatchTopic checks if a given topic matches a filter, accounting for filter
// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c.
func MatchTopic(filter string, topic string) (elements []string, matched bool) {
filterParts := strings.Split(filter, "/")
topicParts := strings.Split(topic, "/")
elements = make([]string, 0)
for i := 0; i < len(filterParts); i++ {
if i >= len(topicParts) {
matched = false
return
}
if filterParts[i] == "+" {
elements = append(elements, topicParts[i])
continue
}
if filterParts[i] == "#" {
matched = true
elements = append(elements, strings.Join(topicParts[i:], "/"))
return
}
if filterParts[i] != topicParts[i] {
matched = false
return
}
}
return elements, true
}
// Ledger is an auth ledger containing access rules for users and topics.
type Ledger struct {
sync.Mutex `json:"-" yaml:"-"`
Users Users `json:"users" yaml:"users"`
Auth AuthRules `json:"auth" yaml:"auth"`
ACL ACLRules `json:"acl" yaml:"acl"`
}
// Update updates the internal values of the ledger.
func (l *Ledger) Update(ln *Ledger) {
l.Lock()
defer l.Unlock()
l.Auth = ln.Auth
l.ACL = ln.ACL
}
// AuthOk returns true if the rules indicate the user is allowed to authenticate.
func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
if l.Users != nil {
if u, ok := l.Users[string(cl.Properties.Username)]; ok &&
u.Password != "" &&
u.Password == RString(pk.Connect.Password) {
return 0, !u.Disallow
}
}
// If there's no users map, or no user was found, attempt to find a matching
// rule (which may also contain a user).
for n, rule := range l.Auth {
if rule.Client.Matches(cl.ID) &&
rule.Username.Matches(string(cl.Properties.Username)) &&
rule.Password.Matches(string(pk.Connect.Password)) &&
rule.Remote.Matches(cl.Net.Remote) {
return n, rule.Allow
}
}
return 0, false
}
// ACLOk returns true if the rules indicate the user is allowed to read or write to
// a specific filter or topic respectively, based on the write bool.
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
// If the users map is set, always check for a predefined user first instead
// of iterating through global rules.
if l.Users != nil {
if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 {
for filter, access := range u.ACL {
if filter.FilterMatches(topic) {
if !write && (access == ReadOnly || access == ReadWrite) {
return n, true
} else if write && (access == WriteOnly || access == ReadWrite) {
return n, true
} else {
return n, false
}
}
}
}
}
for n, rule := range l.ACL {
if rule.Client.Matches(cl.ID) &&
rule.Username.Matches(string(cl.Properties.Username)) &&
rule.Remote.Matches(cl.Net.Remote) {
if len(rule.Filters) == 0 {
return n, true
}
for filter, access := range rule.Filters {
if filter.FilterMatches(topic) {
if !write && (access == ReadOnly || access == ReadWrite) {
return n, true
} else if write && (access == WriteOnly || access == ReadWrite) {
return n, true
} else {
return n, false
}
}
}
}
}
return 0, true
}
// ToJSON encodes the values into a JSON string.
func (l *Ledger) ToJSON() (data []byte, err error) {
return json.Marshal(l)
}
// ToYAML encodes the values into a YAML string.
func (l *Ledger) ToYAML() (data []byte, err error) {
return yaml.Marshal(l)
}
// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct.
func (l *Ledger) Unmarshal(data []byte) error {
l.Lock()
defer l.Unlock()
if len(data) == 0 {
return nil
}
if data[0] == '{' {
return json.Unmarshal(data, l)
}
return yaml.Unmarshal(data, &l)
}

View File

@ -0,0 +1,610 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package auth
import (
"testing"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
var (
checkLedger = Ledger{
Users: Users{ // users are allowed by default
"mochi-co": {
Password: "melon",
ACL: Filters{
"d/+/f": Deny,
"mochi-co/#": ReadWrite,
"readonly": ReadOnly,
},
},
"suspended-username": {
Password: "any",
Disallow: true,
},
"mochi": { // ACL only, will defer to AuthRules for authentication
ACL: Filters{
"special/mochi": ReadOnly,
"secret/mochi": Deny,
"ignored": ReadWrite,
},
},
},
Auth: AuthRules{
{Username: "banned-user"}, // never allow specific username
{Remote: "127.0.0.1", Allow: true}, // always allow localhost
{Remote: "123.123.123.123"}, // disallow any from specific address
{Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address
{Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username)
{Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass
{Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map)
},
ACL: ACLRules{
{
Username: "mochi", // allow matching user/pass
Filters: Filters{
"a/b/c": Deny,
"d/+/f": Deny,
"mochi/#": ReadWrite,
"updates/#": WriteOnly,
"readonly": ReadOnly,
"ignored": Deny,
},
},
{Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost
{Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin
{Remote: "001.002.003.004"}, // Allow all with no filter
{Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others
},
}
)
func TestRStringMatches(t *testing.T) {
require.True(t, RString("*").Matches("any"))
require.True(t, RString("*").Matches(""))
require.True(t, RString("").Matches("any"))
require.True(t, RString("").Matches(""))
require.False(t, RString("no").Matches("any"))
require.False(t, RString("no").Matches(""))
}
func TestCanAuthenticate(t *testing.T) {
tt := []struct {
desc string
client *mqtt.Client
pk packets.Packet
n int
ok bool
}{
{
desc: "allow all local 127.0.0.1",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{Connect: packets.ConnectParams{}},
ok: true,
n: 1,
},
{
desc: "allow username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 5,
},
{
desc: "deny username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: false,
n: 0,
},
{
desc: "allow all local 127.0.0.1",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: true,
n: 1,
},
{
desc: "allow username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 5,
},
{
desc: "deny username/password",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
ok: false,
n: 0,
},
{
desc: "deny client from address",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("not-mochi"),
},
Net: mqtt.ClientConnection{
Remote: "111.144.155.166",
},
},
pk: packets.Packet{},
ok: false,
n: 3,
},
{
desc: "allow remote wildcard",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
Net: mqtt.ClientConnection{
Remote: "111.0.0.1",
},
},
pk: packets.Packet{},
ok: true,
n: 4,
},
{
desc: "never allow username",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("banned-user"),
},
Net: mqtt.ClientConnection{
Remote: "127.0.0.1",
},
},
pk: packets.Packet{},
ok: false,
n: 0,
},
{
desc: "matching user in users",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi-co"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
ok: true,
n: 0,
},
{
desc: "never user in users",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("suspended-user"),
},
},
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}},
ok: false,
n: 0,
},
}
for _, d := range tt {
t.Run(d.desc, func(t *testing.T) {
n, ok := checkLedger.AuthOk(d.client, d.pk)
require.Equal(t, d.n, n)
require.Equal(t, d.ok, ok)
})
}
}
func TestCanACL(t *testing.T) {
tt := []struct {
client *mqtt.Client
desc string
topic string
n int
write bool
ok bool
}{
{
desc: "allow normal write on any other filter",
client: &mqtt.Client{},
topic: "default/acl/write/access",
write: true,
ok: true,
},
{
desc: "allow normal read on any other filter",
client: &mqtt.Client{},
topic: "default/acl/read/access",
write: false,
ok: true,
},
{
desc: "deny user on literal filter",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "a/b/c",
},
{
desc: "deny user on partial filter",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "d/j/f",
},
{
desc: "allow read/write to user path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "mochi/read/write",
write: true,
ok: true,
},
{
desc: "deny read on write-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/no/reading",
write: false,
ok: false,
},
{
desc: "deny read on write-only path ext",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/mochi",
write: false,
ok: false,
},
{
desc: "allow read on not-acl path (no #)",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates",
write: false,
ok: true,
},
{
desc: "allow write on write-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "updates/mochi",
write: true,
ok: true,
},
{
desc: "deny write on read-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "readonly",
write: true,
ok: false,
},
{
desc: "allow read on read-only path",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "readonly",
write: false,
ok: true,
},
{
desc: "allow $sys access to localhost",
client: &mqtt.Client{
Net: mqtt.ClientConnection{
Remote: "localhost",
},
},
topic: "$SYS/test",
write: false,
ok: true,
n: 1,
},
{
desc: "allow $sys access to admin",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("admin"),
},
},
topic: "$SYS/test",
write: false,
ok: true,
n: 2,
},
{
desc: "deny $sys access to all others",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "$SYS/test",
write: false,
ok: false,
n: 4,
},
{
desc: "allow all with no filter",
client: &mqtt.Client{
Net: mqtt.ClientConnection{
Remote: "001.002.003.004",
},
},
topic: "any/path",
write: true,
ok: true,
n: 3,
},
{
desc: "use users embedded acl deny",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "secret/mochi",
write: true,
ok: false,
},
{
desc: "use users embedded acl any",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "any/mochi",
write: true,
ok: true,
},
{
desc: "use users embedded acl write on read-only",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "special/mochi",
write: true,
ok: false,
},
{
desc: "use users embedded acl read on read-only",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "special/mochi",
write: false,
ok: true,
},
{
desc: "preference users embedded acl",
client: &mqtt.Client{
Properties: mqtt.ClientProperties{
Username: []byte("mochi"),
},
},
topic: "ignored",
write: true,
ok: true,
},
}
for _, d := range tt {
t.Run(d.desc, func(t *testing.T) {
n, ok := checkLedger.ACLOk(d.client, d.topic, d.write)
require.Equal(t, d.n, n)
require.Equal(t, d.ok, ok)
})
}
}
func TestMatchTopic(t *testing.T) {
el, matched := MatchTopic("a/+/c/+", "a/b/c/d")
require.True(t, matched)
require.Equal(t, []string{"b", "d"}, el)
el, matched = MatchTopic("a/+/+/+", "a/b/c/d")
require.True(t, matched)
require.Equal(t, []string{"b", "c", "d"}, el)
el, matched = MatchTopic("stuff/#", "stuff/things/yeah")
require.True(t, matched)
require.Equal(t, []string{"things/yeah"}, el)
el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds")
require.True(t, matched)
require.Equal(t, []string{"b", "c/d/as/dds"}, el)
el, matched = MatchTopic("test", "test")
require.True(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic("things/stuff//", "things/stuff/")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic("t", "t2")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
el, matched = MatchTopic(" ", " ")
require.False(t, matched)
require.Equal(t, make([]string, 0), el)
}
var (
ledgerStruct = Ledger{
Users: Users{
"mochi": {
Password: "peach",
ACL: Filters{
"readonly": ReadOnly,
"deny": Deny,
},
},
},
Auth: AuthRules{
{
Client: "*",
Username: "mochi-co",
Password: "melon",
Remote: "192.168.1.*",
Allow: true,
},
},
ACL: ACLRules{
{
Client: "*",
Username: "mochi-co",
Remote: "127.*",
Filters: Filters{
"readonly": ReadOnly,
"writeonly": WriteOnly,
"readwrite": ReadWrite,
"deny": Deny,
},
},
},
}
ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`)
ledgerYAML = []byte(`users:
mochi:
password: peach
acl:
deny: 0
readonly: 1
auth:
- client: '*'
username: mochi-co
remote: 192.168.1.*
password: melon
allow: true
acl:
- client: '*'
username: mochi-co
remote: 127.*
filters:
deny: 0
readonly: 1
readwrite: 3
writeonly: 2
`)
)
func TestLedgerUpdate(t *testing.T) {
old := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
},
}
new := &Ledger{
Auth: AuthRules{
{Remote: "127.0.0.1", Allow: true},
{Remote: "192.168.*", Allow: true},
},
}
old.Update(new)
require.Len(t, old.Auth, 2)
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
require.NotSame(t, new, old)
}
func TestLedgerToJSON(t *testing.T) {
data, err := ledgerStruct.ToJSON()
require.NoError(t, err)
require.Equal(t, ledgerJSON, data)
}
func TestLedgerToYAML(t *testing.T) {
data, err := ledgerStruct.ToYAML()
require.NoError(t, err)
require.Equal(t, ledgerYAML, data)
}
func TestLedgerUnmarshalFromYAML(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal(ledgerYAML)
require.NoError(t, err)
require.Equal(t, &ledgerStruct, l)
require.NotSame(t, l, &ledgerStruct)
}
func TestLedgerUnmarshalFromJSON(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal(ledgerJSON)
require.NoError(t, err)
require.Equal(t, &ledgerStruct, l)
require.NotSame(t, l, &ledgerStruct)
}
func TestLedgerUnmarshalNil(t *testing.T) {
l := new(Ledger)
err := l.Unmarshal([]byte{})
require.NoError(t, err)
require.Equal(t, new(Ledger), l)
}

View File

@ -0,0 +1,250 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package debug
import (
"strings"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/rs/zerolog"
)
// Options contains configuration settings for the debug output.
type Options struct {
ShowPacketData bool // include decoded packet data (default false)
ShowPings bool // show ping requests and responses (default false)
ShowPasswords bool // show connecting user passwords (default false)
}
// Hook is a debugging hook which logs additional low-level information from the server.
type Hook struct {
mqtt.HookBase
config *Options
Log *zerolog.Logger
}
// ID returns the ID of the hook.
func (h *Hook) ID() string {
return "debug"
}
// Provides indicates that this hook provides all methods.
func (h *Hook) Provides(b byte) bool {
return true
}
// Init is called when the hook is initialized.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
return nil
}
// SetOpts is called when the hook receives inheritable server parameters.
func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) {
h.Log = l
h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send()
}
// Stop is called when the hook is stopped.
func (h *Hook) Stop() error {
h.Log.Debug().Str("method", "Stop").Send()
return nil
}
// OnStarted is called when the server starts.
func (h *Hook) OnStarted() {
h.Log.Debug().Str("method", "OnStarted").Send()
}
// OnStopped is called when the server stops.
func (h *Hook) OnStopped() {
h.Log.Debug().Str("method", "OnStopped").Send()
}
// OnPacketRead is called when a new packet is received from a client.
func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
return pk, nil
}
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
return pk, nil
}
// OnPacketSent is called when a packet is sent to a client.
func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
return
}
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
}
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic")
}
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out")
}
// OnQosComplete is called when the Qos flow for a message has been completed.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete")
}
// OnQosDropped is called the Qos flow for a message expires.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped")
}
// OnLWTSent is called when a will message has been issued from a disconnecting client.
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client")
}
// OnRetainedExpired is called when the server clears expired retained messages.
func (h *Hook) OnRetainedExpired(filter string) {
h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired")
}
// OnClientExpired is called when the server clears an expired client.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired")
}
// StoredClients is called when the server restores clients from a store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
h.Log.Debug().
Str("method", "StoredClients").
Send()
return v, nil
}
// StoredClients is called when the server restores subscriptions from a store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
h.Log.Debug().
Str("method", "StoredSubscriptions").
Send()
return v, nil
}
// StoredClients is called when the server restores retained messages from a store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
h.Log.Debug().
Str("method", "StoredRetainedMessages").
Send()
return v, nil
}
// StoredClients is called when the server restores inflight messages from a store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
h.Log.Debug().
Str("method", "StoredInflightMessages").
Send()
return v, nil
}
// StoredClients is called when the server restores system info from a store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
h.Log.Debug().
Str("method", "StoredClients").
Send()
return v, nil
}
// packetMeta adds additional type-specific metadata to the debug logs.
func (h *Hook) packetMeta(pk packets.Packet) map[string]any {
m := map[string]any{}
switch pk.FixedHeader.Type {
case packets.Connect:
m["id"] = pk.Connect.ClientIdentifier
m["clean"] = pk.Connect.Clean
m["keepalive"] = pk.Connect.Keepalive
m["version"] = pk.ProtocolVersion
m["username"] = string(pk.Connect.Username)
if h.config.ShowPasswords {
m["password"] = string(pk.Connect.Password)
}
if pk.Connect.WillFlag {
m["will_topic"] = pk.Connect.WillTopic
m["will_payload"] = string(pk.Connect.WillPayload)
}
case packets.Publish:
m["topic"] = pk.TopicName
m["payload"] = string(pk.Payload)
m["raw"] = pk.Payload
m["qos"] = pk.FixedHeader.Qos
m["id"] = pk.PacketID
case packets.Connack:
fallthrough
case packets.Disconnect:
fallthrough
case packets.Puback:
fallthrough
case packets.Pubrec:
fallthrough
case packets.Pubrel:
fallthrough
case packets.Pubcomp:
m["id"] = pk.PacketID
m["reason"] = int(pk.ReasonCode)
if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 {
m["reason_string"] = pk.Properties.ReasonString
}
case packets.Subscribe:
f := map[string]int{}
ids := map[string]int{}
for _, v := range pk.Filters {
f[v.Filter] = int(v.Qos)
ids[v.Filter] = v.Identifier
}
m["filters"] = f
m["subids"] = f
case packets.Unsubscribe:
f := []string{}
for _, v := range pk.Filters {
f = append(f, v.Filter)
}
m["filters"] = f
case packets.Suback:
fallthrough
case packets.Unsuback:
r := []int{}
for _, v := range pk.ReasonCodes {
r = append(r, int(v))
}
m["reasons"] = r
case packets.Auth:
// tbd
}
if h.config.ShowPacketData {
m["packet"] = pk
}
return m
}

View File

@ -0,0 +1,473 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"bytes"
"errors"
"strings"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/timshannon/badgerhold"
)
const (
// defaultDbFile is the default file path for the badger db file.
defaultDbFile = ".badger"
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the BadgerDB instance.
type Options struct {
Options *badgerhold.Options
Path string
}
// Hook is a persistent storage hook based using BadgerDB file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the BadgerDB instance.
db *badgerhold.Store // the BadgerDB instance.
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "badger-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// Init initializes and connects to the badger instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
if h.config.Path == "" {
h.config.Path = defaultDbFile
}
options := badgerhold.DefaultOptions
options.Dir = h.config.Path
options.ValueDir = h.config.Path
options.Logger = h
var err error
h.db, err = badgerhold.Open(options)
if err != nil {
return err
}
return nil
}
// Stop closes the badger instance.
func (h *Hook) Stop() error {
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data")
}
}
// OnDisconnect removes a client from the store if their session has expired.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
h.updateClient(cl)
if !expire {
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription))
if err != nil {
h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
PacketID: pk.PacketID,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(inflightKey(cl, pk), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.Upsert(in.ID, in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data")
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(retainedKey(filter), new(storage.Message))
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.Delete(clientKey(cl), new(storage.Client))
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.ClientKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.SubscriptionKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.RetainedKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Get(storage.SysInfoKey, &v)
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
return
}
return v, nil
}
// Errorf satisfies the badger interface for an error logger.
func (h *Hook) Errorf(m string, v ...interface{}) {
h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Warningf satisfies the badger interface for a warning logger.
func (h *Hook) Warningf(m string, v ...interface{}) {
h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Infof satisfies the badger interface for an info logger.
func (h *Hook) Infof(m string, v ...interface{}) {
h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}
// Debugf satisfies the badger interface for a debug logger.
func (h *Hook) Debugf(m string, v ...interface{}) {
h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
}

View File

@ -0,0 +1,681 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package badger
import (
"os"
"strings"
"testing"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"github.com/timshannon/badgerhold"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
h.db.Badger().Close()
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
require.NoError(t, err)
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "badger-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.db.Get(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.db.Get(clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.db.Get(clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, badgerhold.ErrNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.db.Upsert(clientKey, &storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.db.Get(clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.db.Get(clientKey, r)
require.Error(t, err)
require.ErrorIs(t, badgerhold.ErrNotFound, err)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.db.Get(clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, badgerhold.ErrNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.db.Get(retainedKey(pk.TopicName), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.db.Upsert(m.ID, m)
require.NoError(t, err)
r := new(storage.Message)
err = h.db.Get(m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.db.Get(m.ID, r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnRetainExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.db.Get(inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.db.Get(inflightKey(client, pk), r)
require.Error(t, err)
require.ErrorIs(t, err, badgerhold.ErrNotFound)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.db.Get(storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.db.Upsert("cl1", &storage.Client{ID: "cl1", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Upsert("cl2", &storage.Client{ID: "cl2", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Upsert("cl3", &storage.Client{ID: "cl3", T: storage.ClientKey})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.db.Upsert("sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Upsert("sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Upsert("sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("m2", &storage.Message{ID: "m2", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("m3", &storage.Message{ID: "m3", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Upsert(storage.SysInfoKey, &storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestErrorf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Errorf("test", 1, 2, 3)
}
func TestWarningf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Warningf("test", 1, 2, 3)
}
func TestInfof(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Infof("test", 1, 2, 3)
}
func TestDebugf(t *testing.T) {
// coverage: one day check log hook
h := new(Hook)
h.SetOpts(&logger, nil)
h.Debugf("test", 1, 2, 3)
}

View File

@ -0,0 +1,474 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
package bolt
import (
"bytes"
"errors"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
sgob "github.com/asdine/storm/codec/gob"
"github.com/asdine/storm/v3"
"go.etcd.io/bbolt"
)
const (
// defaultDbFile is the default file path for the boltdb file.
defaultDbFile = "bolt.db"
// defaultTimeout is the default time to hold a connection to the file.
defaultTimeout = 250 * time.Millisecond
)
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return storage.RetainedKey + "_" + topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the bolt instance.
type Options struct {
Options *bbolt.Options
Path string
}
// Hook is a persistent storage hook based using boltdb file store as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for configuring the boltdb instance.
db *storm.DB // the boltdb instance.
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "bolt-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnWillSent,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// Init initializes and connects to the boltdb instance.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
if config == nil {
config = new(Options)
}
h.config = config.(*Options)
if h.config.Options == nil {
h.config.Options = &bbolt.Options{
Timeout: defaultTimeout,
}
}
if h.config.Path == "" {
h.config.Path = defaultDbFile
}
var err error
h.db, err = storm.Open(h.config.Path, storm.BoltOptions(0600, h.config.Options), storm.Codec(sgob.Codec))
if err != nil {
return err
}
return nil
}
// Stop closes the boltdb instance.
func (h *Hook) Stop() error {
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data")
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if !expire {
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.DeleteStruct(&storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", subscriptionKey(cl, pk.Filters[i].Filter)).
Msg("failed to delete client")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.DeleteStruct(&storage.Message{
ID: retainedKey(pk.TopicName),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", retainedKey(pk.TopicName)).
Msg("failed to delete retained publish")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save retained publish data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Str("client", cl.ID).
Interface("data", in).
Msg("failed to save qos inflight data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Message{
ID: inflightKey(cl, pk),
})
if err != nil {
h.Log.Error().Err(err).
Str("id", inflightKey(cl, pk)).
Msg("failed to delete inflight data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.Save(in)
if err != nil {
h.Log.Error().Err(err).
Interface("data", in).
Msg("failed to save $SYS data")
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
if err != nil && !errors.Is(err, storm.ErrNotFound) {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.ClientKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.SubscriptionKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.RetainedKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.Find("T", storage.InflightKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err = h.db.One("ID", storage.SysInfoKey, &v)
if err != nil && !errors.Is(err, storm.ErrNotFound) {
return
}
return v, nil
}

View File

@ -0,0 +1,717 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package bolt
import (
"os"
"testing"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/asdine/storm/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func teardown(t *testing.T, path string, h *Hook) {
h.Stop()
err := os.Remove(path)
require.NoError(t, err)
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, storage.InflightKey+"_cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
require.Equal(t, "bolt-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitUseDefaults(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
require.Equal(t, defaultTimeout, h.config.Options.Timeout)
require.Equal(t, defaultDbFile, h.config.Path)
}
func TestInitBadPath(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Path: "..",
})
require.Error(t, err)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
err = h.db.One("ID", clientKey(client), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
err = h.db.One("ID", clientKey(client), r2)
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
err = h.db.One("ID", clientKey(client), r3)
require.Error(t, err)
require.ErrorIs(t, storm.ErrNotFound, err)
require.Empty(t, r3.ID)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
err = h.db.One("ID", clientKey(client), r)
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnClientExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err = h.db.Save(&storage.Client{ID: cl.ID})
require.NoError(t, err)
r := new(storage.Client)
err = h.db.One("ID", clientKey, r)
require.NoError(t, err)
require.Equal(t, cl.ID, r.ID)
h.OnClientExpired(cl)
err = h.db.One("ID", clientKey, r)
require.Error(t, err)
require.ErrorIs(t, storm.ErrNotFound, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnSubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
err = h.db.One("ID", retainedKey(pk.TopicName), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpired(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err = h.db.Save(m)
require.NoError(t, err)
r := new(storage.Message)
err = h.db.One("ID", m.ID, r)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
err = h.db.One("ID", m.ID, r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
err = h.db.One("ID", inflightKey(client, pk), r)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved to bolt
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
err = h.db.One("ID", inflightKey(client, pk), r)
require.Error(t, err)
require.Equal(t, storm.ErrNotFound, err)
}
func TestOnQosPublishNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
err = h.db.One("ID", storage.SysInfoKey, r)
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with clients
err = h.db.Save(&storage.Client{ID: "cl1", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Save(&storage.Client{ID: "cl2", T: storage.ClientKey})
require.NoError(t, err)
err = h.db.Save(&storage.Client{ID: "cl3", T: storage.ClientKey})
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredClientsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredClients()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with subscriptions
err = h.db.Save(&storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Save(&storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
require.NoError(t, err)
err = h.db.Save(&storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptionsClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m2", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m3", T: storage.RetainedKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with messages
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
require.NoError(t, err)
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessagesClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSysInfo(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h.config.Path, h)
// populate with sys info
err = h.db.Save(&storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
})
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfoClosedDB(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
teardown(t, h.config.Path, h)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.Error(t, err)
}

View File

@ -0,0 +1,513 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
redis "github.com/go-redis/redis/v8"
)
// defaultAddr is the default address to the redis service.
const defaultAddr = "localhost:6379"
// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt.
const defaultHPrefix = "mochi-"
// clientKey returns a primary key for a client.
func clientKey(cl *mqtt.Client) string {
return cl.ID
}
// subscriptionKey returns a primary key for a subscription.
func subscriptionKey(cl *mqtt.Client, filter string) string {
return cl.ID + ":" + filter
}
// retainedKey returns a primary key for a retained message.
func retainedKey(topic string) string {
return topic
}
// inflightKey returns a primary key for an inflight message.
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
return cl.ID + ":" + pk.FormatID()
}
// sysInfoKey returns a primary key for system info.
func sysInfoKey() string {
return storage.SysInfoKey
}
// Options contains configuration settings for the bolt instance.
type Options struct {
HPrefix string
Options *redis.Options
}
// Hook is a persistent storage hook based using Redis as a backend.
type Hook struct {
mqtt.HookBase
config *Options // options for connecting to the Redis instance.
db *redis.Client // the Redis instance
ctx context.Context // a context for the connection
}
// ID returns the id of the hook.
func (h *Hook) ID() string {
return "redis-db"
}
// Provides indicates which hook methods this hook provides.
func (h *Hook) Provides(b byte) bool {
return bytes.Contains([]byte{
mqtt.OnSessionEstablished,
mqtt.OnDisconnect,
mqtt.OnSubscribed,
mqtt.OnUnsubscribed,
mqtt.OnRetainMessage,
mqtt.OnQosPublish,
mqtt.OnQosComplete,
mqtt.OnQosDropped,
mqtt.OnWillSent,
mqtt.OnSysInfoTick,
mqtt.OnClientExpired,
mqtt.OnRetainedExpired,
mqtt.StoredClients,
mqtt.StoredInflightMessages,
mqtt.StoredRetainedMessages,
mqtt.StoredSubscriptions,
mqtt.StoredSysInfo,
}, []byte{b})
}
// hKey returns a hash set key with a unique prefix.
func (h *Hook) hKey(s string) string {
return h.config.HPrefix + s
}
// Init initializes and connects to the redis service.
func (h *Hook) Init(config any) error {
if _, ok := config.(*Options); !ok && config != nil {
return mqtt.ErrInvalidConfigType
}
h.ctx = context.Background()
if config == nil {
config = &Options{
Options: &redis.Options{
Addr: defaultAddr,
},
}
}
h.config = config.(*Options)
if h.config.HPrefix == "" {
h.config.HPrefix = defaultHPrefix
}
h.Log.Info().
Str("address", h.config.Options.Addr).
Str("username", h.config.Options.Username).
Int("password-len", len(h.config.Options.Password)).
Int("db", h.config.Options.DB).
Msg("connecting to redis service")
h.db = redis.NewClient(h.config.Options)
_, err := h.db.Ping(context.Background()).Result()
if err != nil {
return fmt.Errorf("failed to ping service: %w", err)
}
h.Log.Info().Msg("connected to redis service")
return nil
}
// Close closes the redis connection.
func (h *Hook) Stop() error {
h.Log.Info().Msg("disconnecting from redis service")
return h.db.Close()
}
// OnSessionEstablished adds a client to the store when their session is established.
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// OnWillSent is called when a client sends a will message and the will message is removed
// from the client record.
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
h.updateClient(cl)
}
// updateClient writes the client data to the store.
func (h *Hook) updateClient(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := cl.Properties.Props.Copy(false)
in := &storage.Client{
ID: clientKey(cl),
T: storage.ClientKey,
Remote: cl.Net.Remote,
Listener: cl.Net.Listener,
Username: cl.Properties.Username,
Clean: cl.Properties.Clean,
ProtocolVersion: cl.Properties.ProtocolVersion,
Properties: storage.ClientProperties{
SessionExpiryInterval: props.SessionExpiryInterval,
AuthenticationMethod: props.AuthenticationMethod,
AuthenticationData: props.AuthenticationData,
RequestProblemInfo: props.RequestProblemInfo,
RequestResponseInfo: props.RequestResponseInfo,
ReceiveMaximum: props.ReceiveMaximum,
TopicAliasMaximum: props.TopicAliasMaximum,
User: props.User,
MaximumPacketSize: props.MaximumPacketSize,
},
Will: storage.ClientWill(cl.Properties.Will),
}
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data")
}
}
// OnDisconnect removes a client from the store if they were using a clean session.
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if !expire {
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
}
}
// OnSubscribed adds one or more client subscriptions to the store.
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
var in *storage.Subscription
for i := 0; i < len(pk.Filters); i++ {
in = &storage.Subscription{
ID: subscriptionKey(cl, pk.Filters[i].Filter),
T: storage.SubscriptionKey,
Client: cl.ID,
Qos: reasonCodes[i],
Filter: pk.Filters[i].Filter,
Identifier: pk.Filters[i].Identifier,
NoLocal: pk.Filters[i].NoLocal,
RetainHandling: pk.Filters[i].RetainHandling,
RetainAsPublished: pk.Filters[i].RetainAsPublished,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data")
}
}
}
// OnUnsubscribed removes one or more client subscriptions from the store.
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
for i := 0; i < len(pk.Filters); i++ {
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data")
}
}
}
// OnRetainMessage adds a retained message for a topic to the store.
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
if r == -1 {
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data")
}
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: retainedKey(pk.TopicName),
T: storage.RetainedKey,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Created: pk.Created,
Origin: pk.Origin,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data")
}
}
// OnQosPublish adds or updates an inflight message in the store.
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
props := pk.Properties.Copy(false)
in := &storage.Message{
ID: inflightKey(cl, pk),
T: storage.InflightKey,
Origin: pk.Origin,
FixedHeader: pk.FixedHeader,
TopicName: pk.TopicName,
Payload: pk.Payload,
Sent: sent,
Created: pk.Created,
Properties: storage.MessageProperties{
PayloadFormat: props.PayloadFormat,
MessageExpiryInterval: props.MessageExpiryInterval,
ContentType: props.ContentType,
ResponseTopic: props.ResponseTopic,
CorrelationData: props.CorrelationData,
SubscriptionIdentifier: props.SubscriptionIdentifier,
TopicAlias: props.TopicAlias,
User: props.User,
},
}
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data")
}
}
// OnQosComplete removes a resolved inflight message from the store.
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
}
}
// OnQosDropped removes a dropped inflight message from the store.
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
}
h.OnQosComplete(cl, pk)
}
// OnSysInfoTick stores the latest system info in the store.
func (h *Hook) OnSysInfoTick(sys *system.Info) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
in := &storage.SystemInfo{
ID: sysInfoKey(),
T: storage.SysInfoKey,
Info: *sys,
}
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
if err != nil {
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data")
}
}
// OnRetainedExpired deletes expired retained messages from the store.
func (h *Hook) OnRetainedExpired(filter string) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
}
}
// OnClientExpired deleted expired clients from the store.
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
if err != nil {
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
}
}
// StoredClients returns all stored clients from the store.
func (h *Hook) StoredClients() (v []storage.Client, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll client data")
return
}
for _, row := range rows {
var d storage.Client
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data")
}
v = append(v, d)
}
return v, nil
}
// StoredSubscriptions returns all stored subscriptions from the store.
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll subscription data")
return
}
for _, row := range rows {
var d storage.Subscription
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data")
}
v = append(v, d)
}
return v, nil
}
// StoredRetainedMessages returns all stored retained messages from the store.
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll retained message data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data")
}
v = append(v, d)
}
return v, nil
}
// StoredInflightMessages returns all stored inflight messages from the store.
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
if err != nil && !errors.Is(err, redis.Nil) {
h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data")
return
}
for _, row := range rows {
var d storage.Message
if err = d.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
}
v = append(v, d)
}
return v, nil
}
// StoredSysInfo returns the system info from the store.
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.db == nil {
h.Log.Error().Err(storage.ErrDBFileNotOpen)
return
}
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
if err != nil && !errors.Is(err, redis.Nil) {
return
}
if err = v.UnmarshalBinary([]byte(row)); err != nil {
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data")
}
return v, nil
}

View File

@ -0,0 +1,789 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package redis
import (
"os"
"sort"
"testing"
"time"
"github.com/mochi-co/mqtt/v2"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
miniredis "github.com/alicebob/miniredis/v2"
redis "github.com/go-redis/redis/v8"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
client = &mqtt.Client{
ID: "test",
Net: mqtt.ClientConnection{
Remote: "test.addr",
Listener: "listener",
},
Properties: mqtt.ClientProperties{
Username: []byte("username"),
Clean: false,
},
}
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
)
func newHook(t *testing.T, addr string) *Hook {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: addr,
},
})
require.NoError(t, err)
return h
}
func teardown(t *testing.T, h *Hook) {
if h.db != nil {
err := h.db.FlushAll(h.ctx).Err()
require.NoError(t, err)
h.Stop()
}
}
func TestClientKey(t *testing.T) {
k := clientKey(&mqtt.Client{ID: "cl1"})
require.Equal(t, "cl1", k)
}
func TestSubscriptionKey(t *testing.T) {
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
require.Equal(t, "cl1:a/b/c", k)
}
func TestRetainedKey(t *testing.T) {
k := retainedKey("a/b/c")
require.Equal(t, "a/b/c", k)
}
func TestInflightKey(t *testing.T) {
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
require.Equal(t, "cl1:1", k)
}
func TestSysInfoKey(t *testing.T) {
require.Equal(t, storage.SysInfoKey, sysInfoKey())
}
func TestID(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.Equal(t, "redis-db", h.ID())
}
func TestProvides(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
require.True(t, h.Provides(mqtt.OnSessionEstablished))
require.True(t, h.Provides(mqtt.OnDisconnect))
require.True(t, h.Provides(mqtt.OnSubscribed))
require.True(t, h.Provides(mqtt.OnUnsubscribed))
require.True(t, h.Provides(mqtt.OnRetainMessage))
require.True(t, h.Provides(mqtt.OnQosPublish))
require.True(t, h.Provides(mqtt.OnQosComplete))
require.True(t, h.Provides(mqtt.OnQosDropped))
require.True(t, h.Provides(mqtt.OnSysInfoTick))
require.True(t, h.Provides(mqtt.StoredClients))
require.True(t, h.Provides(mqtt.StoredInflightMessages))
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
require.True(t, h.Provides(mqtt.StoredSubscriptions))
require.True(t, h.Provides(mqtt.StoredSysInfo))
require.False(t, h.Provides(mqtt.OnACLCheck))
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
}
func TestHKey(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.SetOpts(&logger, nil)
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
}
func TestInitUseDefaults(t *testing.T) {
s := miniredis.RunT(t)
s.StartAddr(defaultAddr)
defer s.Close()
h := newHook(t, defaultAddr)
h.SetOpts(&logger, nil)
err := h.Init(nil)
require.NoError(t, err)
defer teardown(t, h)
require.Equal(t, defaultHPrefix, h.config.HPrefix)
require.Equal(t, defaultAddr, h.config.Options.Addr)
}
func TestInitBadConfig(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(map[string]any{})
require.Error(t, err)
}
func TestInitBadAddr(t *testing.T) {
h := new(Hook)
h.SetOpts(&logger, nil)
err := h.Init(&Options{
Options: &redis.Options{
Addr: "abc:123",
},
})
require.Error(t, err)
}
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
h.OnSessionEstablished(client, packets.Packet{})
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
require.Equal(t, client.Net.Remote, r.Remote)
require.Equal(t, client.Net.Listener, r.Listener)
require.Equal(t, client.Properties.Username, r.Username)
require.Equal(t, client.Properties.Clean, r.Clean)
require.NotSame(t, client, r)
h.OnDisconnect(client, nil, false)
r2 := new(storage.Client)
row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r2.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.ID)
h.OnDisconnect(client, nil, true)
r3 := new(storage.Client)
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
require.Empty(t, r3.ID)
}
func TestOnSessionEstablishedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnSessionEstablishedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSessionEstablished(client, packets.Packet{})
}
func TestOnWillSent(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
c1 := client
c1.Properties.Will.Flag = 1
h.OnWillSent(c1, packets.Packet{})
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, uint32(1), r.Will.Flag)
require.NotSame(t, client, r)
}
func TestOnClientExpired(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
cl := &mqtt.Client{ID: "cl1"}
clientKey := clientKey(cl)
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err()
require.NoError(t, err)
r := new(storage.Client)
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, clientKey, r.ID)
h.OnClientExpired(cl)
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
require.Error(t, err)
require.ErrorIs(t, redis.Nil, err)
}
func TestOnClientExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnClientExpired(client)
}
func TestOnClientExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnClientExpired(client)
}
func TestOnDisconnectNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnDisconnect(client, nil, false)
}
func TestOnDisconnectClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnDisconnect(client, nil, false)
}
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
h.OnSubscribed(client, pkf, []byte{0})
r := new(storage.Subscription)
row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, client.ID, r.Client)
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
require.Equal(t, byte(0), r.Qos)
h.OnUnsubscribed(client, pkf)
_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnSubscribedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnSubscribedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSubscribed(client, pkf, []byte{0})
}
func TestOnUnsubscribedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnUnsubscribed(client, pkf)
}
func TestOnUnsubscribedClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnUnsubscribed(client, pkf)
}
func TestOnRetainMessageThenUnset(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnRetainMessage(client, pk, 1)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
h.OnRetainMessage(client, pk, -1)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
// coverage: delete deleted
h.OnRetainMessage(client, pk, -1)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpired(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
m := &storage.Message{
ID: retainedKey("a/b/c"),
T: storage.RetainedKey,
TopicName: "a/b/c",
}
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err()
require.NoError(t, err)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, m.TopicName, r.TopicName)
h.OnRetainedExpired(m.TopicName)
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnRetainedExpiredClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainedExpiredNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainedExpired("a/b/c")
}
func TestOnRetainMessageNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnRetainMessageClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnRetainMessage(client, packets.Packet{}, 0)
}
func TestOnQosPublishThenQOSComplete(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
pk := packets.Packet{
FixedHeader: packets.FixedHeader{
Retain: true,
Qos: 2,
},
Payload: []byte("hello"),
TopicName: "a/b/c",
}
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
r := new(storage.Message)
row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, pk.TopicName, r.TopicName)
require.Equal(t, pk.Payload, r.Payload)
// ensure dates are properly saved to bolt
require.True(t, r.Sent > 0)
require.True(t, time.Now().Unix()-1 < r.Sent)
// OnQosDropped is a passthrough to OnQosComplete here
h.OnQosDropped(client, pk)
_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
require.Error(t, err)
require.ErrorIs(t, err, redis.Nil)
}
func TestOnQosPublishNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosPublishClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
}
func TestOnQosCompleteNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosCompleteClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnQosComplete(client, packets.Packet{})
}
func TestOnQosDroppedNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnQosDropped(client, packets.Packet{})
}
func TestOnSysInfoTick(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
info := &system.Info{
Version: "2.0.0",
BytesReceived: 100,
}
h.OnSysInfoTick(info)
r := new(storage.SystemInfo)
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
require.NoError(t, err)
err = r.UnmarshalBinary([]byte(row))
require.NoError(t, err)
require.Equal(t, info.Version, r.Version)
require.Equal(t, info.BytesReceived, r.BytesReceived)
require.NotSame(t, info, r)
}
func TestOnSysInfoTickClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
h.OnSysInfoTick(new(system.Info))
}
func TestOnSysInfoTickNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
h.OnSysInfoTick(new(system.Info))
}
func TestStoredClients(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with clients
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err()
require.NoError(t, err)
r, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "cl1", r[0].ID)
require.Equal(t, "cl2", r[1].ID)
require.Equal(t, "cl3", r[2].ID)
}
func TestStoredClientsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredClients()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredClientsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredClients()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSubscriptions(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with subscriptions
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err()
require.NoError(t, err)
r, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "sub1", r[0].ID)
require.Equal(t, "sub2", r[1].ID)
require.Equal(t, "sub3", r[2].ID)
}
func TestStoredSubscriptionsNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSubscriptionsClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredSubscriptions()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredRetainedMessages(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with messages
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
require.NoError(t, err)
r, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "m1", r[0].ID)
require.Equal(t, "m2", r[1].ID)
require.Equal(t, "m3", r[2].ID)
}
func TestStoredRetainedMessagesNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredRetainedMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredInflightMessages(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with messages
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
require.NoError(t, err)
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
require.NoError(t, err)
r, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, r, 3)
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
require.Equal(t, "i1", r[0].ID)
require.Equal(t, "i2", r[1].ID)
require.Equal(t, "i3", r[2].ID)
}
func TestStoredInflightMessagesNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredInflightMessagesClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredInflightMessages()
require.Empty(t, v)
require.Error(t, err)
}
func TestStoredSysInfo(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
defer teardown(t, h)
// populate with sys info
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey,
&storage.SystemInfo{
ID: storage.SysInfoKey,
Info: system.Info{
Version: "2.0.0",
},
T: storage.SysInfoKey,
}).Err()
require.NoError(t, err)
r, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", r.Info.Version)
}
func TestStoredSysInfoNoDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
h.db = nil
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.NoError(t, err)
}
func TestStoredSysInfoClosedDB(t *testing.T) {
s := miniredis.RunT(t)
defer s.Close()
h := newHook(t, s.Addr())
teardown(t, h)
v, err := h.StoredSysInfo()
require.Empty(t, v)
require.Error(t, err)
}

View File

@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"encoding/json"
"errors"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
)
const (
SubscriptionKey = "SUB" // unique key to denote Subscriptions in a store
SysInfoKey = "SYS" // unique key to denote server system information in a store
RetainedKey = "RET" // unique key to denote retained messages in a store
InflightKey = "IFM" // unique key to denote inflight messages in a store
ClientKey = "CL" // unique key to denote clients in a store
)
var (
// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading.
ErrDBFileNotOpen = errors.New("db file not open")
)
// Client is a storable representation of an mqtt client.
type Client struct {
Will ClientWill `json:"will"` // will topic and payload data if applicable
Properties ClientProperties `json:"properties"` // the connect properties for the client
Username []byte `json:"username"` // the username of the client
ID string `json:"id" storm:"id"` // the client id / storage key
T string `json:"t"` // the data type (client)
Remote string `json:"remote"` // the remote address of the client
Listener string `json:"listener"` // the listener the client connected on
ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client
Clean bool `json:"clean"` // if the client requested a clean start/session
}
// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection.
type ClientProperties struct {
AuthenticationData []byte `json:"authenticationData"`
User []packets.UserProperty `json:"user"`
AuthenticationMethod string `json:"authenticationMethod"`
SessionExpiryInterval uint32 `json:"sessionExpiryInterval"`
MaximumPacketSize uint32 `json:"maximumPacketSize"`
ReceiveMaximum uint16 `json:"receiveMaximum"`
TopicAliasMaximum uint16 `json:"topicAliasMaximum"`
SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag"`
RequestProblemInfo byte `json:"requestProblemInfo"`
RequestProblemInfoFlag bool `json:"requestProblemInfoFlag"`
RequestResponseInfo byte `json:"requestResponseInfo"`
}
// ClientWill contains a will message for a client, and limited mqtt v5 properties.
type ClientWill struct {
Payload []byte `json:"payload"`
User []packets.UserProperty `json:"user"`
TopicName string `json:"topicName"`
Flag uint32 `json:"flag"`
WillDelayInterval uint32 `json:"willDelayInterval"`
Qos byte `json:"qos"`
Retain bool `json:"retain"`
}
// MarshalBinary encodes the values into a json string.
func (d Client) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Client) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// Message is a storable representation of an MQTT message (specifically publish).
type Message struct {
Properties MessageProperties `json:"properties"` // -
Payload []byte `json:"payload"` // the message payload (if retained)
T string `json:"t"` // the data type
ID string `json:"id" storm:"id"` // the storage key
Origin string `json:"origin"` // the id of the client who sent the message
TopicName string `json:"topic_name"` // the topic the message was sent to (if retained)
FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message
Created int64 `json:"created"` // the time the message was created in unixtime
Sent int64 `json:"sent"` // the last time the message was sent (for retries) in unixtime (if inflight)
PacketID uint16 `json:"packet_id"` // the unique id of the packet (if inflight)
}
// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages.
type MessageProperties struct {
CorrelationData []byte `json:"correlationData"`
SubscriptionIdentifier []int `json:"subscriptionIdentifier"`
User []packets.UserProperty `json:"user"`
ContentType string `json:"contentType"`
ResponseTopic string `json:"responseTopic"`
MessageExpiryInterval uint32 `json:"messageExpiry"`
TopicAlias uint16 `json:"topicAlias"`
PayloadFormat byte `json:"payloadFormat"`
PayloadFormatFlag bool `json:"payloadFormatFlag"`
}
// MarshalBinary encodes the values into a json string.
func (d Message) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Message) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// Subscription is a storable representation of an mqtt subscription.
type Subscription struct {
T string `json:"t"`
ID string `json:"id" storm:"id"`
Client string `json:"client"`
Filter string `json:"filter"`
Identifier int `json:"identifier"`
RetainHandling byte `json:"retain_handling"`
Qos byte `json:"qos"`
RetainAsPublished bool `json:"retain_as_pub"`
NoLocal bool `json:"no_local"`
}
// MarshalBinary encodes the values into a json string.
func (d Subscription) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *Subscription) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}
// SystemInfo is a storable representation of the system information values.
type SystemInfo struct {
system.Info // embed the system info struct
T string `json:"t"` // the data type
ID string `json:"id" storm:"id"` // the storage key
}
// MarshalBinary encodes the values into a json string.
func (d SystemInfo) MarshalBinary() (data []byte, err error) {
return json.Marshal(d)
}
// UnmarshalBinary decodes a json string into a struct.
func (d *SystemInfo) UnmarshalBinary(data []byte) error {
if len(data) == 0 {
return nil
}
return json.Unmarshal(data, d)
}

View File

@ -0,0 +1,196 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package storage
import (
"testing"
"time"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
var (
clientStruct = Client{
ID: "test",
T: "client",
Remote: "remote",
Listener: "listener",
Username: []byte("mochi"),
Clean: true,
Properties: ClientProperties{
SessionExpiryInterval: 2,
SessionExpiryIntervalFlag: true,
AuthenticationMethod: "a",
AuthenticationData: []byte("test"),
RequestProblemInfo: 1,
RequestProblemInfoFlag: true,
RequestResponseInfo: 1,
ReceiveMaximum: 128,
TopicAliasMaximum: 256,
User: []packets.UserProperty{
{Key: "k", Val: "v"},
},
MaximumPacketSize: 120,
},
Will: ClientWill{
Qos: 1,
Payload: []byte("abc"),
TopicName: "a/b/c",
Flag: 1,
Retain: true,
WillDelayInterval: 2,
User: []packets.UserProperty{
{Key: "k2", Val: "v2"},
},
},
}
clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)
messageStruct = Message{
T: "message",
Payload: []byte("payload"),
FixedHeader: packets.FixedHeader{
Remaining: 2,
Type: 3,
Qos: 1,
Dup: true,
Retain: true,
},
ID: "id",
Origin: "mochi",
TopicName: "topic",
Properties: MessageProperties{
PayloadFormat: 1,
PayloadFormatFlag: true,
MessageExpiryInterval: 20,
ContentType: "type",
ResponseTopic: "a/b/r",
CorrelationData: []byte("r"),
SubscriptionIdentifier: []int{1},
TopicAlias: 2,
User: []packets.UserProperty{
{Key: "k2", Val: "v2"},
},
},
Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
PacketID: 100,
}
messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)
subscriptionStruct = Subscription{
T: "subscription",
ID: "id",
Client: "mochi",
Filter: "a/b/c",
Qos: 1,
}
subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","identifier":0,"retain_handling":0,"qos":1,"retain_as_pub":false,"no_local":false}`)
sysInfoStruct = SystemInfo{
T: "info",
ID: "id",
Info: system.Info{
Version: "2.0.0",
Started: 1,
Uptime: 2,
BytesReceived: 3,
BytesSent: 4,
ClientsConnected: 5,
ClientsMaximum: 7,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
PacketsReceived: 12,
PacketsSent: 13,
Retained: 15,
Inflight: 16,
InflightDropped: 17,
},
}
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)
func TestClientMarshalBinary(t *testing.T) {
data, err := clientStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, clientJSON, data)
}
func TestClientUnmarshalBinary(t *testing.T) {
d := clientStruct
err := d.UnmarshalBinary(clientJSON)
require.NoError(t, err)
require.Equal(t, clientStruct, d)
}
func TestClientUnmarshalBinaryEmpty(t *testing.T) {
d := Client{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Client{}, d)
}
func TestMessageMarshalBinary(t *testing.T) {
data, err := messageStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, messageJSON, data)
}
func TestMessageUnmarshalBinary(t *testing.T) {
d := messageStruct
err := d.UnmarshalBinary(messageJSON)
require.NoError(t, err)
require.Equal(t, messageStruct, d)
}
func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
d := Message{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Message{}, d)
}
func TestSubscriptionMarshalBinary(t *testing.T) {
data, err := subscriptionStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, subscriptionJSON, data)
}
func TestSubscriptionUnmarshalBinary(t *testing.T) {
d := subscriptionStruct
err := d.UnmarshalBinary(subscriptionJSON)
require.NoError(t, err)
require.Equal(t, subscriptionStruct, d)
}
func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
d := Subscription{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, Subscription{}, d)
}
func TestSysInfoMarshalBinary(t *testing.T) {
data, err := sysInfoStruct.MarshalBinary()
require.NoError(t, err)
require.Equal(t, sysInfoJSON, data)
}
func TestSysInfoUnmarshalBinary(t *testing.T) {
d := sysInfoStruct
err := d.UnmarshalBinary(sysInfoJSON)
require.NoError(t, err)
require.Equal(t, sysInfoStruct, d)
}
func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
d := SystemInfo{}
err := d.UnmarshalBinary([]byte{})
require.NoError(t, err)
require.Equal(t, SystemInfo{}, d)
}

View File

@ -0,0 +1,634 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"errors"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/hooks/storage"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
type modifiedHookBase struct {
HookBase
err error
fail bool
failAt int
}
var errTestHook = errors.New("error")
func (h *modifiedHookBase) ID() string {
return "modified"
}
func (h *modifiedHookBase) Init(config any) error {
if config != nil {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) Provides(b byte) bool {
return true
}
func (h *modifiedHookBase) Stop() error {
if h.fail {
return errTestHook
}
return nil
}
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
return true
}
func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
return true
}
func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
if h.fail {
if h.err != nil {
return pk, h.err
}
return pk, errTestHook
}
return pk, nil
}
func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) {
if h.fail {
return will, errTestHook
}
return will, nil
}
func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) {
if h.fail || h.failAt == 1 {
return v, errTestHook
}
return []storage.Client{
{ID: "cl1"},
{ID: "cl2"},
{ID: "cl3"},
}, nil
}
func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
if h.fail || h.failAt == 2 {
return v, errTestHook
}
return []storage.Subscription{
{ID: "sub1"},
{ID: "sub2"},
{ID: "sub3"},
}, nil
}
func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) {
if h.fail || h.failAt == 3 {
return v, errTestHook
}
return []storage.Message{
{ID: "r1"},
{ID: "r2"},
{ID: "r3"},
}, nil
}
func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) {
if h.fail || h.failAt == 4 {
return v, errTestHook
}
return []storage.Message{
{ID: "i1"},
{ID: "i2"},
{ID: "i3"},
}, nil
}
func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
if h.fail || h.failAt == 5 {
return v, errTestHook
}
return storage.SystemInfo{
Info: system.Info{
Version: "2.0.0",
},
}, nil
}
type providesCheckHook struct {
HookBase
}
func (h *providesCheckHook) Provides(b byte) bool {
return b == OnConnect
}
func TestHooksProvides(t *testing.T) {
h := new(Hooks)
err := h.Add(new(providesCheckHook), nil)
require.NoError(t, err)
err = h.Add(new(HookBase), nil)
require.NoError(t, err)
require.True(t, h.Provides(OnConnect, OnDisconnect))
require.False(t, h.Provides(OnDisconnect))
}
func TestHooksAddLenGetAll(t *testing.T) {
h := new(Hooks)
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
err = h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(2), h.Len())
all := h.GetAll()
require.Equal(t, "base", all[0].ID())
require.Equal(t, "modified", all[1].ID())
}
func TestHooksAddInitFailure(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), map[string]any{})
require.Error(t, err)
require.Equal(t, int64(0), atomic.LoadInt64(&h.qty))
}
func TestHooksStop(t *testing.T) {
h := new(Hooks)
h.Log = &logger
err := h.Add(new(HookBase), nil)
require.NoError(t, err)
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
require.Equal(t, int64(1), h.Len())
h.Stop()
}
// coverage: also cover some empty functions
func TestHooksNonReturns(t *testing.T) {
h := new(Hooks)
cl := new(Client)
for i := 0; i < 2; i++ {
t.Run("step-"+strconv.Itoa(i), func(t *testing.T) {
// on first iteration, check without hook methods
h.OnStarted()
h.OnStopped()
h.OnSysInfoTick(new(system.Info))
h.OnConnect(cl, packets.Packet{})
h.OnSessionEstablished(cl, packets.Packet{})
h.OnDisconnect(cl, nil, false)
h.OnPacketSent(cl, packets.Packet{}, []byte{})
h.OnPacketProcessed(cl, packets.Packet{}, nil)
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
h.OnUnsubscribed(cl, packets.Packet{})
h.OnPublished(cl, packets.Packet{})
h.OnPublishDropped(cl, packets.Packet{})
h.OnRetainMessage(cl, packets.Packet{}, 0)
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
h.OnQosComplete(cl, packets.Packet{})
h.OnQosDropped(cl, packets.Packet{})
h.OnWillSent(cl, packets.Packet{})
h.OnClientExpired(cl)
h.OnRetainedExpired("a/b/c")
// on second iteration, check added hook methods
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
})
}
}
func TestHooksOnConnectAuthenticate(t *testing.T) {
h := new(Hooks)
ok := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, ok)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
ok = h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.True(t, ok)
}
func TestHooksOnACLCheck(t *testing.T) {
h := new(Hooks)
ok := h.OnACLCheck(new(Client), "a/b/c", true)
require.False(t, ok)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
ok = h.OnACLCheck(new(Client), "a/b/c", true)
require.True(t, ok)
}
func TestHooksOnSubscribe(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
pki := packets.Packet{
Filters: packets.Subscriptions{
{Filter: "a/b/c", Qos: 1},
},
}
pk := h.OnSubscribe(new(Client), pki)
require.EqualValues(t, pk, pki)
}
func TestHooksOnSelectSubscribers(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
subs := &Subscribers{
Subscriptions: map[string]packets.Subscription{
"cl1": {Filter: "a/b/c"},
},
}
subs2 := h.OnSelectSubscribers(subs, packets.Packet{})
require.EqualValues(t, subs, subs2)
}
func TestHooksOnUnsubscribe(t *testing.T) {
h := new(Hooks)
err := h.Add(new(modifiedHookBase), nil)
require.NoError(t, err)
pki := packets.Packet{
Filters: packets.Subscriptions{
{Filter: "a/b/c", Qos: 1},
},
}
pk := h.OnUnsubscribe(new(Client), pki)
require.EqualValues(t, pk, pki)
}
func TestHooksOnPublish(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: failure
hook.fail = true
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
hook.err = packets.ErrRejectPacket
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrRejectPacket)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnPacketRead(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: failure
hook.fail = true
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
// coverage: reject packet
hook.err = packets.ErrRejectPacket
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.ErrorIs(t, err, packets.ErrRejectPacket)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnAuthPacket(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
hook.fail = true
pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.Error(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnPacketEncode(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10})
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHooksOnLWT(t *testing.T) {
h := new(Hooks)
h.Log = &logger
hook := new(modifiedHookBase)
err := h.Add(hook, nil)
require.NoError(t, err)
lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.Equal(t, "a/b/c", lwt.TopicName)
// coverage: fail lwt
hook.fail = true
lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.Equal(t, "a/b/c", lwt.TopicName)
}
func TestHooksStoredClients(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredClients()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredClients()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredClients()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredSubscriptions(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredSubscriptions()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredSubscriptions()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredRetainedMessages(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredRetainedMessages()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredRetainedMessages()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredInflightMessages(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, v, 0)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredInflightMessages()
require.NoError(t, err)
require.Len(t, v, 3)
hook.fail = true
v, err = h.StoredInflightMessages()
require.Error(t, err)
require.Len(t, v, 0)
}
func TestHooksStoredSysInfo(t *testing.T) {
h := new(Hooks)
h.Log = &logger
v, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "", v.Info.Version)
hook := new(modifiedHookBase)
err = h.Add(hook, nil)
require.NoError(t, err)
v, err = h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "2.0.0", v.Info.Version)
hook.fail = true
v, err = h.StoredSysInfo()
require.Error(t, err)
require.Equal(t, "", v.Info.Version)
}
func TestHookBaseID(t *testing.T) {
h := new(HookBase)
require.Equal(t, "base", h.ID())
}
func TestHookBaseProvidesNone(t *testing.T) {
h := new(HookBase)
require.False(t, h.Provides(OnConnect))
require.False(t, h.Provides(OnDisconnect))
}
func TestHookBaseInit(t *testing.T) {
h := new(HookBase)
require.Nil(t, h.Init(nil))
}
func TestHookBaseSetOpts(t *testing.T) {
h := new(HookBase)
h.SetOpts(&logger, new(HookOptions))
require.NotNil(t, h.Log)
require.NotNil(t, h.Opts)
}
func TestHookBaseClose(t *testing.T) {
h := new(HookBase)
require.Nil(t, h.Stop())
}
func TestHookBaseOnConnectAuthenticate(t *testing.T) {
h := new(HookBase)
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
require.False(t, v)
}
func TestHookBaseOnACLCheck(t *testing.T) {
h := new(HookBase)
v := h.OnACLCheck(new(Client), "topic", true)
require.False(t, v)
}
func TestHookBaseOnPublish(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnPacketRead(t *testing.T) {
h := new(HookBase)
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnAuthPacket(t *testing.T) {
h := new(HookBase)
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
require.NoError(t, err)
require.Equal(t, uint16(10), pk.PacketID)
}
func TestHookBaseOnLWT(t *testing.T) {
h := new(HookBase)
lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
require.NoError(t, err)
require.Equal(t, "a/b/c", lwt.TopicName)
}
func TestHookBaseStoredClients(t *testing.T) {
h := new(HookBase)
v, err := h.StoredClients()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredSubscriptions(t *testing.T) {
h := new(HookBase)
v, err := h.StoredSubscriptions()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredInflightMessages(t *testing.T) {
h := new(HookBase)
v, err := h.StoredInflightMessages()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoredRetainedMessages(t *testing.T) {
h := new(HookBase)
v, err := h.StoredRetainedMessages()
require.NoError(t, err)
require.Empty(t, v)
}
func TestHookBaseStoreSysInfo(t *testing.T) {
h := new(HookBase)
v, err := h.StoredSysInfo()
require.NoError(t, err)
require.Equal(t, "", v.Version)
}

View File

@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sort"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/v2/packets"
)
// Inflight is a map of InflightMessage keyed on packet id.
type Inflight struct {
sync.RWMutex
internal map[uint16]packets.Packet // internal contains the inflight packets
receiveQuota int32 // remaining inbound qos quota for flow control
sendQuota int32 // remaining outbound qos quota for flow control
maximumReceiveQuota int32 // maximum allowed receive quota
maximumSendQuota int32 // maximum allowed send quota
}
// NewInflights returns a new instance of an Inflight packets map.
func NewInflights() *Inflight {
return &Inflight{
internal: map[uint16]packets.Packet{},
}
}
// Set adds or updates an inflight packet by packet id.
func (i *Inflight) Set(m packets.Packet) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[m.PacketID]
i.internal[m.PacketID] = m
return !ok
}
// Get returns an inflight packet by packet id.
func (i *Inflight) Get(id uint16) (packets.Packet, bool) {
i.RLock()
defer i.RUnlock()
if m, ok := i.internal[id]; ok {
return m, true
}
return packets.Packet{}, false
}
// Len returns the size of the inflight messages map.
func (i *Inflight) Len() int {
i.RLock()
defer i.RUnlock()
return len(i.internal)
}
// Clone returns a new instance of Inflight with the same message data.
// This is used when transferring inflights from a taken-over session.
func (i *Inflight) Clone() *Inflight {
c := NewInflights()
i.RLock()
defer i.RUnlock()
for k, v := range i.internal {
c.internal[k] = v
}
return c
}
// GetAll returns all the inflight messages.
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
i.RLock()
defer i.RUnlock()
m := []packets.Packet{}
for _, v := range i.internal {
if !immediate || (immediate && v.Expiry < 0) {
m = append(m, v)
}
}
sort.Slice(m, func(i, j int) bool {
return uint16(m[i].Created) < uint16(m[j].Created)
})
return m
}
// NextImmediate returns the next inflight packet which is indicated to be sent immediately.
// This typically occurs when the quota has been exhausted, and we need to wait until new quota
// is free to continue sending.
func (i *Inflight) NextImmediate() (packets.Packet, bool) {
i.RLock()
defer i.RUnlock()
m := i.GetAll(true)
if len(m) > 0 {
return m[0], true
}
return packets.Packet{}, false
}
// Delete removes an in-flight message from the map. Returns true if the message existed.
func (i *Inflight) Delete(id uint16) bool {
i.Lock()
defer i.Unlock()
_, ok := i.internal[id]
delete(i.internal, id)
return ok
}
// TakeRecieveQuota reduces the receive quota by 1.
func (i *Inflight) DecreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) > 0 {
atomic.AddInt32(&i.receiveQuota, -1)
}
}
// TakeRecieveQuota increases the receive quota by 1.
func (i *Inflight) IncreaseReceiveQuota() {
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
atomic.AddInt32(&i.receiveQuota, 1)
}
}
// ResetReceiveQuota resets the receive quota to the maximum allowed value.
func (i *Inflight) ResetReceiveQuota(n int32) {
atomic.StoreInt32(&i.receiveQuota, n)
atomic.StoreInt32(&i.maximumReceiveQuota, n)
}
// DecreaseSendQuota reduces the send quota by 1.
func (i *Inflight) DecreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) > 0 {
atomic.AddInt32(&i.sendQuota, -1)
}
}
// IncreaseSendQuota increases the send quota by 1.
func (i *Inflight) IncreaseSendQuota() {
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
atomic.AddInt32(&i.sendQuota, 1)
}
}
// ResetSendQuota resets the send quota to the maximum allowed value.
func (i *Inflight) ResetSendQuota(n int32) {
atomic.StoreInt32(&i.sendQuota, n)
atomic.StoreInt32(&i.maximumSendQuota, n)
}

View File

@ -0,0 +1,199 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"sync/atomic"
"testing"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
func TestInflightSet(t *testing.T) {
cl, _, _ := newTestClient()
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.True(t, r)
require.NotNil(t, cl.State.Inflight.internal[1])
require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID)
r = cl.State.Inflight.Set(packets.Packet{PacketID: 1})
require.False(t, r)
}
func TestInflightGet(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
msg, ok := cl.State.Inflight.Get(2)
require.True(t, ok)
require.NotEqual(t, 0, msg.PacketID)
}
func TestInflightGetAllAndImmediate(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
require.Equal(t, []packets.Packet{
{PacketID: 1, Created: 1},
{PacketID: 2, Created: 2},
{PacketID: 3, Created: 3, Expiry: -1},
{PacketID: 4, Created: 4, Expiry: -1},
{PacketID: 5, Created: 5},
}, cl.State.Inflight.GetAll(false))
require.Equal(t, []packets.Packet{
{PacketID: 3, Created: 3, Expiry: -1},
{PacketID: 4, Created: 4, Expiry: -1},
}, cl.State.Inflight.GetAll(true))
}
func TestInflightLen(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
}
func TestInflightClone(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
require.Equal(t, 1, cl.State.Inflight.Len())
cloned := cl.State.Inflight.Clone()
require.NotNil(t, cloned)
require.NotSame(t, cloned, cl.State.Inflight)
}
func TestInflightDelete(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
require.NotNil(t, cl.State.Inflight.internal[3])
r := cl.State.Inflight.Delete(3)
require.True(t, r)
require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID)
_, ok := cl.State.Inflight.Get(3)
require.False(t, ok)
r = cl.State.Inflight.Delete(3)
require.False(t, r)
}
func TestResetReceiveQuota(t *testing.T) {
i := NewInflights()
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
i.ResetReceiveQuota(6)
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota))
}
func TestReceiveQuota(t *testing.T) {
i := NewInflights()
i.receiveQuota = 4
i.maximumReceiveQuota = 5
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
// Return 1
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Try to go over max limit
i.IncreaseReceiveQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
// Reset to max 1
i.ResetReceiveQuota(1)
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
// Take 1
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
// Try to go below zero
i.DecreaseReceiveQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
}
func TestResetSendQuota(t *testing.T) {
i := NewInflights()
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
i.ResetSendQuota(6)
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota))
}
func TestSendQuota(t *testing.T) {
i := NewInflights()
i.sendQuota = 4
i.maximumSendQuota = 5
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
// Return 1
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Try to go over max limit
i.IncreaseSendQuota()
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
// Reset to max 1
i.ResetSendQuota(1)
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
// Take 1
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
// Try to go below zero
i.DecreaseSendQuota()
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
}
func TestNextImmediate(t *testing.T) {
cl, _, _ := newTestClient()
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
pk, ok := cl.State.Inflight.NextImmediate()
require.True(t, ok)
require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk)
r := cl.State.Inflight.Delete(3)
require.True(t, r)
pk, ok = cl.State.Inflight.NextImmediate()
require.True(t, ok)
require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk)
r = cl.State.Inflight.Delete(4)
require.True(t, r)
_, ok = cl.State.Inflight.NextImmediate()
require.False(t, ok)
}

View File

@ -0,0 +1,118 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"encoding/json"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/mochi-co/mqtt/v2/system"
"github.com/rs/zerolog"
)
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
type HTTPStats struct {
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // the http server
log *zerolog.Logger // server logger
sysInfo *system.Info // pointers to the server data
end uint32 // ensure the close methods are only called once
}
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats {
if config == nil {
config = new(Config)
}
return &HTTPStats{
id: id,
address: address,
sysInfo: sysInfo,
config: config,
}
}
// ID returns the id of the listener.
func (l *HTTPStats) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *HTTPStats) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *HTTPStats) Protocol() string {
if l.listen != nil && l.listen.TLSConfig != nil {
return "https"
}
return "http"
}
// Init initializes the listener.
func (l *HTTPStats) Init(log *zerolog.Logger) error {
l.log = log
mux := http.NewServeMux()
mux.HandleFunc("/", l.jsonHandler)
l.listen = &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
Addr: l.address,
Handler: mux,
}
if l.config.TLSConfig != nil {
l.listen.TLSConfig = l.config.TLSConfig
}
return nil
}
// Serve starts listening for new connections and serving responses.
func (l *HTTPStats) Serve(establish EstablishFn) {
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *HTTPStats) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
info := *l.sysInfo.Clone()
out, err := json.MarshalIndent(info, "", "\t")
if err != nil {
io.WriteString(w, err.Error())
}
w.Write(out)
}

View File

@ -0,0 +1,127 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/mochi-co/mqtt/v2/system"
"github.com/stretchr/testify/require"
)
func TestNewHTTPStats(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestHTTPStatsID(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "t1", l.ID())
}
func TestHTTPStatsAddress(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, testAddr, l.Address())
}
func TestHTTPStatsProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, nil, nil)
require.Equal(t, "http", l.Protocol())
}
func TestHTTPStatsTLSProtocol(t *testing.T) {
l := NewHTTPStats("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
}, nil)
l.Init(nil)
require.Equal(t, "https", l.Protocol())
}
func TestHTTPStatsInit(t *testing.T) {
sysInfo := new(system.Info)
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
require.NotNil(t, l.sysInfo)
require.Equal(t, sysInfo, l.sysInfo)
require.NotNil(t, l.listen)
require.Equal(t, testAddr, l.listen.Addr)
}
func TestHTTPStatsServeAndClose(t *testing.T) {
sysInfo := &system.Info{
Version: "test",
}
// setup http stats listener
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
// get body from stats address
resp, err := http.Get("http://localhost" + testAddr)
require.NoError(t, err)
require.NotNil(t, resp)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// decode body from json and check data
v := new(system.Info)
err = json.Unmarshal(body, v)
require.NoError(t, err)
require.Equal(t, "test", v.Version)
// ensure listening is closed
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
_, err = http.Get("http://localhost" + testAddr)
require.Error(t, err)
<-o
}
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
sysInfo := &system.Info{
Version: "test",
}
l := NewHTTPStats("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
}, sysInfo)
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
l.Close(MockCloser)
}

View File

@ -0,0 +1,135 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"net"
"sync"
"github.com/rs/zerolog"
)
// Config contains configuration values for a listener.
type Config struct {
// TLSConfig is a tls.Config configuration to be used with the listener.
// See examples folder for basic and mutual-tls use.
TLSConfig *tls.Config
}
// EstablishFn is a callback function for establishing new clients.
type EstablishFn func(id string, c net.Conn) error
// CloseFunc is a callback function for closing all listener clients.
type CloseFn func(id string)
// Listener is an interface for network listeners. A network listener listens
// for incoming client connections and adds them to the server.
type Listener interface {
Init(*zerolog.Logger) error // open the network address
Serve(EstablishFn) // starting actively listening for new connections
ID() string // return the id of the listener
Address() string // the address of the listener
Protocol() string // the protocol in use by the listener
Close(CloseFn) // stop and close the listener
}
// Listeners contains the network listeners for the broker.
type Listeners struct {
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
internal map[string]Listener // a map of active listeners.
sync.RWMutex
}
// New returns a new instance of Listeners.
func New() *Listeners {
return &Listeners{
internal: map[string]Listener{},
}
}
// Add adds a new listener to the listeners map, keyed on id.
func (l *Listeners) Add(val Listener) {
l.Lock()
defer l.Unlock()
l.internal[val.ID()] = val
}
// Get returns the value of a listener if it exists.
func (l *Listeners) Get(id string) (Listener, bool) {
l.RLock()
defer l.RUnlock()
val, ok := l.internal[id]
return val, ok
}
// Len returns the length of the listeners map.
func (l *Listeners) Len() int {
l.RLock()
defer l.RUnlock()
return len(l.internal)
}
// Delete removes a listener from the internal map.
func (l *Listeners) Delete(id string) {
l.Lock()
defer l.Unlock()
delete(l.internal, id)
}
// Serve starts a listener serving from the internal map.
func (l *Listeners) Serve(id string, establisher EstablishFn) {
l.RLock()
defer l.RUnlock()
listener := l.internal[id]
go func(e EstablishFn) {
defer l.wg.Done()
l.wg.Add(1)
listener.Serve(e)
}(establisher)
}
// ServeAll starts all listeners serving from the internal map.
func (l *Listeners) ServeAll(establisher EstablishFn) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
for id := range l.internal {
ids[i] = id
i++
}
l.RUnlock()
for _, id := range ids {
l.Serve(id, establisher)
}
}
// Close stops a listener from the internal map.
func (l *Listeners) Close(id string, closer CloseFn) {
l.RLock()
defer l.RUnlock()
if listener, ok := l.internal[id]; ok {
listener.Close(closer)
}
}
// CloseAll iterates and closes all registered listeners.
func (l *Listeners) CloseAll(closer CloseFn) {
l.RLock()
i := 0
ids := make([]string, len(l.internal))
for id := range l.internal {
ids[i] = id
i++
}
l.RUnlock()
for _, id := range ids {
l.Close(id, closer)
}
l.wg.Wait()
}

View File

@ -0,0 +1,177 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"log"
"os"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
const testAddr = ":22222"
var (
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)
tlsConfigBasic *tls.Config
)
func init() {
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
if err != nil {
log.Fatal(err)
}
// Basic TLS Config
tlsConfigBasic = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
}
func TestNew(t *testing.T) {
l := New()
require.NotNil(t, l.internal)
}
func TestAddListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
require.Contains(t, l.internal, "t1")
}
func TestGetListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
g, ok := l.Get("t1")
require.True(t, ok)
require.Equal(t, g.ID(), "t1")
}
func TestLenListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
require.Contains(t, l.internal, "t1")
require.Contains(t, l.internal, "t2")
require.Equal(t, 2, l.Len())
}
func TestDeleteListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
require.Contains(t, l.internal, "t1")
l.Delete("t1")
_, ok := l.Get("t1")
require.False(t, ok)
require.Nil(t, l.internal["t1"])
}
func TestServeListener(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
require.False(t, l.internal["t1"].(*MockListener).IsServing())
}
func TestServeAllListeners(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
l.Add(NewMockListener("t3", testAddr))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
require.True(t, l.internal["t2"].(*MockListener).IsServing())
require.True(t, l.internal["t3"].(*MockListener).IsServing())
l.Close("t1", MockCloser)
l.Close("t2", MockCloser)
l.Close("t3", MockCloser)
require.False(t, l.internal["t1"].(*MockListener).IsServing())
require.False(t, l.internal["t2"].(*MockListener).IsServing())
require.False(t, l.internal["t3"].(*MockListener).IsServing())
}
func TestCloseListener(t *testing.T) {
l := New()
mocked := NewMockListener("t1", testAddr)
l.Add(mocked)
l.Serve("t1", MockEstablisher)
time.Sleep(time.Millisecond)
var closed bool
l.Close("t1", func(id string) {
closed = true
})
require.True(t, closed)
}
func TestCloseAllListeners(t *testing.T) {
l := New()
l.Add(NewMockListener("t1", testAddr))
l.Add(NewMockListener("t2", testAddr))
l.Add(NewMockListener("t3", testAddr))
l.ServeAll(MockEstablisher)
time.Sleep(time.Millisecond)
require.True(t, l.internal["t1"].(*MockListener).IsServing())
require.True(t, l.internal["t2"].(*MockListener).IsServing())
require.True(t, l.internal["t3"].(*MockListener).IsServing())
closed := make(map[string]bool)
l.CloseAll(func(id string) {
closed[id] = true
})
require.Contains(t, closed, "t1")
require.Contains(t, closed, "t2")
require.Contains(t, closed, "t3")
require.True(t, closed["t1"])
require.True(t, closed["t2"])
require.True(t, closed["t3"])
}

View File

@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"fmt"
"net"
"sync"
"github.com/rs/zerolog"
)
// MockEstablisher is a function signature which can be used in testing.
func MockEstablisher(id string, c net.Conn) error {
return nil
}
// MockCloser is a function signature which can be used in testing.
func MockCloser(id string) {}
// MockListener is a mock listener for establishing client connections.
type MockListener struct {
sync.RWMutex
id string // the id of the listener
address string // the network address the listener binds to
Config *Config // configuration for the listener
done chan bool // indicate the listener is done
Serving bool // indicate the listener is serving
Listening bool // indiciate the listener is listening
ErrListen bool // throw an error on listen
}
// NewMockListener returns a new instance of MockListener.
func NewMockListener(id, address string) *MockListener {
return &MockListener{
id: id,
address: address,
done: make(chan bool),
}
}
// Serve serves the mock listener.
func (l *MockListener) Serve(establisher EstablishFn) {
l.Lock()
l.Serving = true
l.Unlock()
for range l.done {
return
}
}
// Init initializes the listener.
func (l *MockListener) Init(log *zerolog.Logger) error {
if l.ErrListen {
return fmt.Errorf("listen failure")
}
l.Lock()
defer l.Unlock()
l.Listening = true
return nil
}
// ID returns the id of the mock listener.
func (l *MockListener) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *MockListener) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *MockListener) Protocol() string {
return "mock"
}
// Close closes the mock listener.
func (l *MockListener) Close(closer CloseFn) {
l.Lock()
defer l.Unlock()
l.Serving = false
closer(l.id)
close(l.done)
}
// IsServing indicates whether the mock listener is serving.
func (l *MockListener) IsServing() bool {
l.Lock()
defer l.Unlock()
return l.Serving
}
// IsListening indicates whether the mock listener is listening.
func (l *MockListener) IsListening() bool {
l.Lock()
defer l.Unlock()
return l.Listening
}

View File

@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestMockEstablisher(t *testing.T) {
_, w := net.Pipe()
err := MockEstablisher("t1", w)
require.NoError(t, err)
w.Close()
}
func TestNewMockListener(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.id)
require.Equal(t, testAddr, mocked.address)
}
func TestMockListenerID(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.ID())
}
func TestMockListenerAddress(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, testAddr, mocked.Address())
}
func TestMockListenerProtocol(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "mock", mocked.Protocol())
}
func TestNewMockListenerIsListening(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsListening())
}
func TestNewMockListenerIsServing(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsServing())
}
func TestNewMockListenerInit(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, "t1", mocked.id)
require.Equal(t, testAddr, mocked.address)
require.Equal(t, false, mocked.IsListening())
err := mocked.Init(nil)
require.NoError(t, err)
require.Equal(t, true, mocked.IsListening())
}
func TestNewMockListenerInitFailure(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
mocked.ErrListen = true
err := mocked.Init(nil)
require.Error(t, err)
}
func TestMockListenerServe(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
require.Equal(t, false, mocked.IsServing())
o := make(chan bool)
go func(o chan bool) {
mocked.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
require.Equal(t, true, mocked.IsServing())
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
mocked.Init(nil)
}
func TestMockListenerClose(t *testing.T) {
mocked := NewMockListener("t1", testAddr)
var closed bool
mocked.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}

View File

@ -0,0 +1,88 @@
package listeners
import (
"net"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// Net is a listener for establishing client connections on basic TCP protocol.
type Net struct { // [MQTT-4.2.0-1]
mu sync.Mutex
listener net.Listener // a net.Listener which will listen for new clients
id string // the internal id of the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
func NewNet(id string, listener net.Listener) *Net {
return &Net{
id: id,
listener: listener,
}
}
// ID returns the id of the listener.
func (l *Net) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Net) Address() string {
return l.listener.Addr().String()
}
// Protocol returns the network of the listener.
func (l *Net) Protocol() string {
return l.listener.Addr().Network()
}
// Init initializes the listener.
func (l *Net) Init(log *zerolog.Logger) error {
l.log = log
return nil
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *Net) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listener.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *Net) Close(closeClients CloseFn) {
l.mu.Lock()
defer l.mu.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listener != nil {
err := l.listener.Close()
if err != nil {
return
}
}
}

View File

@ -0,0 +1,105 @@
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewNet(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.id)
}
func TestNetID(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "t1", l.ID())
}
func TestNetAddress(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, n.Addr().String(), l.Address())
}
func TestNetProtocol(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
require.Equal(t, "tcp", l.Protocol())
}
func TestNetInit(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
}
func TestNetServeAndClose(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestNetEstablishThenEnd(t *testing.T) {
n, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
l := NewNet("t1", n)
err = l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", n.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"crypto/tls"
"net"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// TCP is a listener for establishing client connections on basic TCP protocol.
type TCP struct { // [MQTT-4.2.0-1]
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
listen net.Listener // a net.Listener which will listen for new clients
config *Config // configuration values for the listener
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once
}
// NewTCP initialises and returns a new TCP listener, listening on an address.
func NewTCP(id, address string, config *Config) *TCP {
if config == nil {
config = new(Config)
}
return &TCP{
id: id,
address: address,
config: config,
}
}
// ID returns the id of the listener.
func (l *TCP) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *TCP) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *TCP) Protocol() string {
return "tcp"
}
// Init initializes the listener.
func (l *TCP) Init(log *zerolog.Logger) error {
l.log = log
var err error
if l.config.TLSConfig != nil {
l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig)
} else {
l.listen, err = net.Listen("tcp", l.address)
}
return err
}
// Serve starts waiting for new TCP connections, and calls the establish
// connection callback for any received.
func (l *TCP) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *TCP) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@ -0,0 +1,131 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewTCP(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestTCPID(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "t1", l.ID())
}
func TestTCPAddress(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, testAddr, l.Address())
}
func TestTCPProtocol(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
require.Equal(t, "tcp", l.Protocol())
}
func TestTCPProtocolTLS(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
l.Init(&logger)
defer l.listen.Close()
require.Equal(t, "tcp", l.Protocol())
}
func TestTCPInit(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewTCP("t2", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
require.NotNil(t, l2.config.TLSConfig)
}
func TestTCPServeAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestTCPServeTLSAndClose(t *testing.T) {
l := NewTCP("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
<-o
}
func TestTCPEstablishThenEnd(t *testing.T) {
l := NewTCP("t1", testAddr, nil)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("tcp", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"net"
"os"
"sync"
"sync/atomic"
"github.com/rs/zerolog"
)
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
type UnixSock struct {
sync.RWMutex
id string // the internal id of the listener.
address string // the network address to bind to.
listen net.Listener // a net.Listener which will listen for new clients.
log *zerolog.Logger // server logger
end uint32 // ensure the close methods are only called once.
}
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
func NewUnixSock(id, address string) *UnixSock {
return &UnixSock{
id: id,
address: address,
}
}
// ID returns the id of the listener.
func (l *UnixSock) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *UnixSock) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *UnixSock) Protocol() string {
return "unix"
}
// Init initializes the listener.
func (l *UnixSock) Init(log *zerolog.Logger) error {
l.log = log
var err error
_ = os.Remove(l.address)
l.listen, err = net.Listen("unix", l.address)
return err
}
// Serve starts waiting for new UnixSock connections, and calls the establish
// connection callback for any received.
func (l *UnixSock) Serve(establish EstablishFn) {
for {
if atomic.LoadUint32(&l.end) == 1 {
return
}
conn, err := l.listen.Accept()
if err != nil {
return
}
if atomic.LoadUint32(&l.end) == 0 {
go func() {
err = establish(l.id, conn)
if err != nil {
l.log.Warn().Err(err).Send()
}
}()
}
}
}
// Close closes the listener and any client connections.
func (l *UnixSock) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
closeClients(l.id)
}
if l.listen != nil {
err := l.listen.Close()
if err != nil {
return
}
}
}

View File

@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: jason@zgwit.com
package listeners
import (
"errors"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
const testUnixAddr = "mochi.sock"
func TestNewUnixSock(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.id)
require.Equal(t, testUnixAddr, l.address)
}
func TestUnixSockID(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "t1", l.ID())
}
func TestUnixSockAddress(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, testUnixAddr, l.Address())
}
func TestUnixSockProtocol(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
require.Equal(t, "unix", l.Protocol())
}
func TestUnixSockInit(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
l.Close(MockCloser)
require.NoError(t, err)
l2 := NewUnixSock("t2", testUnixAddr)
err = l2.Init(&logger)
l2.Close(MockCloser)
require.NoError(t, err)
}
func TestUnixSockServeAndClose(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
l.Close(MockCloser) // coverage: close closed
l.Serve(MockEstablisher) // coverage: serve closed
}
func TestUnixSockEstablishThenEnd(t *testing.T) {
l := NewUnixSock("t1", testUnixAddr)
err := l.Init(&logger)
require.NoError(t, err)
o := make(chan bool)
established := make(chan bool)
go func() {
l.Serve(func(id string, c net.Conn) error {
established <- true
return errors.New("ending") // return an error to exit immediately
})
o <- true
}()
time.Sleep(time.Millisecond)
net.Dial("unix", l.listen.Addr().String())
require.Equal(t, true, <-established)
l.Close(MockCloser)
<-o
}

View File

@ -0,0 +1,178 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"context"
"errors"
"io"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
var (
// ErrInvalidMessage indicates that a message payload was not valid.
ErrInvalidMessage = errors.New("message type not binary")
)
// Websocket is a listener for establishing websocket connections.
type Websocket struct { // [MQTT-4.2.0-1]
sync.RWMutex
id string // the internal id of the listener
address string // the network address to bind to
config *Config // configuration values for the listener
listen *http.Server // an http server for serving websocket connections
log *zerolog.Logger // server logger
establish EstablishFn // the server's establish connection handler
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
end uint32 // ensure the close methods are only called once
}
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
func NewWebsocket(id, address string, config *Config) *Websocket {
if config == nil {
config = new(Config)
}
return &Websocket{
id: id,
address: address,
config: config,
upgrader: &websocket.Upgrader{
Subprotocols: []string{"mqtt"},
CheckOrigin: func(r *http.Request) bool {
return true
},
},
}
}
// ID returns the id of the listener.
func (l *Websocket) ID() string {
return l.id
}
// Address returns the address of the listener.
func (l *Websocket) Address() string {
return l.address
}
// Protocol returns the address of the listener.
func (l *Websocket) Protocol() string {
if l.config.TLSConfig != nil {
return "wss"
}
return "ws"
}
// Init initializes the listener.
func (l *Websocket) Init(log *zerolog.Logger) error {
l.log = log
mux := http.NewServeMux()
mux.HandleFunc("/", l.handler)
l.listen = &http.Server{
Addr: l.address,
Handler: mux,
TLSConfig: l.config.TLSConfig,
ReadTimeout: 60 * time.Second,
WriteTimeout: 60 * time.Second,
}
return nil
}
// handler upgrades and handles an incoming websocket connection.
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
c, err := l.upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
err = l.establish(l.id, &wsConn{c.UnderlyingConn(), c})
if err != nil {
l.log.Warn().Err(err).Send()
}
}
// Serve starts waiting for new Websocket connections, and calls the connection
// establishment callback for any received.
func (l *Websocket) Serve(establish EstablishFn) {
l.establish = establish
if l.listen.TLSConfig != nil {
l.listen.ListenAndServeTLS("", "")
} else {
l.listen.ListenAndServe()
}
}
// Close closes the listener and any client connections.
func (l *Websocket) Close(closeClients CloseFn) {
l.Lock()
defer l.Unlock()
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
l.listen.Shutdown(ctx)
}
closeClients(l.id)
}
// wsConn is a websocket connection which satisfies the net.Conn interface.
type wsConn struct {
net.Conn
c *websocket.Conn
}
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
func (ws *wsConn) Read(p []byte) (int, error) {
op, r, err := ws.c.NextReader()
if err != nil {
return 0, err
}
if op != websocket.BinaryMessage {
err = ErrInvalidMessage
return 0, err
}
var n, br int
for {
br, err = r.Read(p[n:])
n += br
if err != nil {
if errors.Is(err, io.EOF) {
err = nil
}
return n, err
}
}
}
// Write writes bytes to the websocket connection.
func (ws *wsConn) Write(p []byte) (int, error) {
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
if err != nil {
return 0, err
}
return len(p), nil
}
// Close signals the underlying websocket conn to close.
func (ws *wsConn) Close() error {
return ws.Conn.Close()
}

View File

@ -0,0 +1,114 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package listeners
import (
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
)
func TestNewWebsocket(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "t1", l.id)
require.Equal(t, testAddr, l.address)
}
func TestWebsocketID(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "t1", l.ID())
}
func TestWebsocketAddress(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, testAddr, l.Address())
}
func TestWebsocketProtocol(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Equal(t, "ws", l.Protocol())
}
func TestWebsocketProtocoTLS(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
require.Equal(t, "wss", l.Protocol())
}
func TestWebsockeInit(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
require.Nil(t, l.listen)
err := l.Init(nil)
require.NoError(t, err)
require.NotNil(t, l.listen)
}
func TestWebsocketServeAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.True(t, closed)
<-o
}
func TestWebsocketServeTLSAndClose(t *testing.T) {
l := NewWebsocket("t1", testAddr, &Config{
TLSConfig: tlsConfigBasic,
})
err := l.Init(nil)
require.NoError(t, err)
o := make(chan bool)
go func(o chan bool) {
l.Serve(MockEstablisher)
o <- true
}(o)
time.Sleep(time.Millisecond)
var closed bool
l.Close(func(id string) {
closed = true
})
require.Equal(t, true, closed)
}
func TestWebsocketUpgrade(t *testing.T) {
l := NewWebsocket("t1", testAddr, nil)
l.Init(nil)
e := make(chan bool)
l.establish = func(id string, c net.Conn) error {
e <- true
return nil
}
s := httptest.NewServer(http.HandlerFunc(l.handler))
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
require.NoError(t, err)
require.Equal(t, true, <-e)
s.Close()
ws.Close()
}

View File

@ -0,0 +1,172 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"encoding/binary"
"io"
"unicode/utf8"
"unsafe"
)
// bytesToString provides a zero-alloc no-copy byte to string conversion.
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
func bytesToString(bs []byte) string {
return *(*string)(unsafe.Pointer(&bs))
}
// decodeUint16 extracts the value of two bytes from a byte array.
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
if len(buf) < offset+2 {
return 0, 0, ErrMalformedOffsetUintOutOfRange
}
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
}
// decodeUint32 extracts the value of four bytes from a byte array.
func decodeUint32(buf []byte, offset int) (uint32, int, error) {
if len(buf) < offset+4 {
return 0, 0, ErrMalformedOffsetUintOutOfRange
}
return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil
}
// decodeString extracts a string from a byte array, beginning at an offset.
func decodeString(buf []byte, offset int) (string, int, error) {
b, n, err := decodeBytes(buf, offset)
if err != nil {
return "", 0, err
}
if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5]
return "", 0, ErrMalformedInvalidUTF8
}
return bytesToString(b), n, nil
}
// validUTF8 checks if the byte array contains valid UTF-8 characters.
func validUTF8(b []byte) bool {
return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2]
}
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
length, next, err := decodeUint16(buf, offset)
if err != nil {
return make([]byte, 0), 0, err
}
if next+int(length) > len(buf) {
return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange
}
return buf[next : next+int(length)], next + int(length), nil
}
// decodeByte extracts the value of a byte from a byte array.
func decodeByte(buf []byte, offset int) (byte, int, error) {
if len(buf) <= offset {
return 0, 0, ErrMalformedOffsetByteOutOfRange
}
return buf[offset], offset + 1, nil
}
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
if len(buf) <= offset {
return false, 0, ErrMalformedOffsetBoolOutOfRange
}
return 1&buf[offset] > 0, offset + 1, nil
}
// encodeBool returns a byte instead of a bool.
func encodeBool(b bool) byte {
if b {
return 1
}
return 0
}
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
func encodeBytes(val []byte) []byte {
// In most circumstances the number of bytes being encoded is small.
// Setting the cap to a low amount allows us to account for those without
// triggering allocation growth on append unless we need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, val...)
}
// encodeUint16 encodes a uint16 value to a byte array.
func encodeUint16(val uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, val)
return buf
}
// encodeUint32 encodes a uint16 value to a byte array.
func encodeUint32(val uint32) []byte {
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, val)
return buf
}
// encodeString encodes a string to a byte array.
func encodeString(val string) []byte {
// Like encodeBytes, we set the cap to a small number to avoid
// triggering allocation growth on append unless we absolutely need to.
buf := make([]byte, 2, 32)
binary.BigEndian.PutUint16(buf, uint16(len(val)))
return append(buf, []byte(val)...)
}
// encodeLength writes length bits for the header.
func encodeLength(b *bytes.Buffer, length int64) {
// 1.5.5 Variable Byte Integer encode non-normative
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
for {
eb := byte(length % 128)
length /= 128
if length > 0 {
eb |= 0x80
}
b.WriteByte(eb)
if length == 0 {
break // [MQTT-1.5.5-1]
}
}
}
func DecodeLength(b io.ByteReader) (n, bu int, err error) {
// see 1.5.5 Variable Byte Integer decode non-normative
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
var multiplier uint32
var value uint32
bu = 1
for {
eb, err := b.ReadByte()
if err != nil {
return 0, bu, err
}
value |= uint32(eb&127) << multiplier
if value > 268435455 {
return 0, bu, ErrMalformedVariableByteInteger
}
if (eb & 128) == 0 {
break
}
multiplier += 7
bu++
}
return int(value), bu, nil
}

View File

@ -0,0 +1,422 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"errors"
"fmt"
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBytesToString(t *testing.T) {
b := []byte{'a', 'b', 'c'}
require.Equal(t, "abc", bytesToString(b))
}
func TestDecodeString(t *testing.T) {
expect := []struct {
name string
rawBytes []byte
result string
offset int
shouldFail error
}{
{
offset: 0,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: "a/b/c/d",
},
{
offset: 14,
rawBytes: []byte{
Connect << 4, 17, // Fixed header
0, 6, // Protocol Name - MSB+LSB
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
3, // Protocol Version
0, // Packet Flags
0, 30, // Keepalive
0, 3, // Client ID - MSB+LSB
'h', 'e', 'y', // Client ID "zen"},
},
result: "hey",
},
{
offset: 2,
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
result: "1/2/3/4/a/b/c/d/e/^/@/!",
},
{
offset: 0,
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
result: "x/y/z",
},
{
offset: 0,
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 5,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 9,
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
{
offset: 17,
rawBytes: []byte{
Connect << 4, 0, // Fixed header
0, 4, // Protocol Name - MSB+LSB
'M', 'Q', 'T', 'T', // Protocol Name
4, // Protocol Version
0, // Flags
0, 20, // Keepalive
0, 3, // Client ID - MSB+LSB
'z', 'e', 'n', // Client ID "zen"
0, 6, // Will Topic - MSB+LSB
'l',
},
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
offset: 0,
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
shouldFail: ErrMalformedInvalidUTF8,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3]
result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0)
require.NoError(t, err)
require.Equal(t, "\ufeff", result)
}
func TestDecodeBytes(t *testing.T) {
expect := []struct {
rawBytes []byte
result []uint8
next int
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session)
result: []byte{0x4d, 0x51, 0x54, 0x54},
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start
result: []byte{0x4d, 0x51, 0x54, 0x54},
next: 6,
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 0,
shouldFail: ErrMalformedOffsetBytesOutOfRange,
},
{
rawBytes: []byte{0, 4, 77, 81},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
})
}
}
func TestDecodeByte(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint8
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
result: uint8(0x00),
offset: 0,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x04),
offset: 1,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x4d),
offset: 2,
},
{
rawBytes: []byte{0, 4, 77, 81, 84, 84},
result: uint8(0x51),
offset: 3,
},
{
rawBytes: []byte{0, 4, 77, 80, 82, 84},
offset: 8,
shouldFail: ErrMalformedOffsetByteOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+1, offset)
})
}
}
func TestDecodeUint16(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint16
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x07),
offset: 0,
},
{
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
result: uint16(0x761),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+2, offset)
})
}
}
func TestDecodeUint32(t *testing.T) {
expect := []struct {
rawBytes []byte
result uint32
offset int
shouldFail error
}{
{
rawBytes: []byte{0, 0, 0, 7, 8},
result: uint32(7),
offset: 0,
},
{
rawBytes: []byte{0, 0, 1, 226, 64, 8},
result: uint32(123456),
offset: 1,
},
{
rawBytes: []byte{0, 7, 255, 47},
offset: 8,
shouldFail: ErrMalformedOffsetUintOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, i+4, offset)
})
}
}
func TestDecodeByteBool(t *testing.T) {
expect := []struct {
rawBytes []byte
result bool
offset int
shouldFail error
}{
{
rawBytes: []byte{0x00, 0x00},
result: false,
},
{
rawBytes: []byte{0x01, 0x00},
result: true,
},
{
rawBytes: []byte{0x01, 0x00},
offset: 5,
shouldFail: ErrMalformedOffsetBoolOutOfRange,
},
}
for i, wanted := range expect {
t.Run(fmt.Sprint(i), func(t *testing.T) {
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
if wanted.shouldFail != nil {
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
return
}
require.NoError(t, err)
require.Equal(t, wanted.result, result)
require.Equal(t, 1, offset)
})
}
}
func TestDecodeLength(t *testing.T) {
b := bytes.NewBuffer([]byte{0x78})
n, bu, err := DecodeLength(b)
require.NoError(t, err)
require.Equal(t, 120, n)
require.Equal(t, 1, bu)
b = bytes.NewBuffer([]byte{255, 255, 255, 127})
n, bu, err = DecodeLength(b)
require.NoError(t, err)
require.Equal(t, 268435455, n)
require.Equal(t, 4, bu)
}
func TestDecodeLengthErrors(t *testing.T) {
b := bytes.NewBuffer([]byte{})
_, _, err := DecodeLength(b)
require.Error(t, err)
b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f})
_, _, err = DecodeLength(b)
require.Error(t, err)
require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
}
func TestEncodeBool(t *testing.T) {
result := encodeBool(true)
require.Equal(t, byte(1), result)
result = encodeBool(false)
require.Equal(t, byte(0), result)
// Check failure.
result = encodeBool(false)
require.NotEqual(t, byte(1), result)
}
func TestEncodeBytes(t *testing.T) {
result := encodeBytes([]byte("testing"))
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result)
result = encodeBytes([]byte("testing"))
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result)
}
func TestEncodeUint16(t *testing.T) {
result := encodeUint16(0)
require.Equal(t, []byte{0x00, 0x00}, result)
result = encodeUint16(32767)
require.Equal(t, []byte{0x7f, 0xff}, result)
result = encodeUint16(math.MaxUint16)
require.Equal(t, []byte{0xff, 0xff}, result)
}
func TestEncodeUint32(t *testing.T) {
result := encodeUint32(7)
require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result)
result = encodeUint32(32767)
require.Equal(t, []byte{0, 0, 127, 255}, result)
result = encodeUint32(math.MaxUint32)
require.Equal(t, []byte{255, 255, 255, 255}, result)
}
func TestEncodeString(t *testing.T) {
result := encodeString("testing")
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result)
result = encodeString("")
require.Equal(t, []uint8{0x00, 0x00}, result)
result = encodeString("a")
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result)
result = encodeString("b")
require.NotEqual(t, []uint8{0x00, 0x00}, result)
}
func TestEncodeLength(t *testing.T) {
b := new(bytes.Buffer)
encodeLength(b, 120)
require.Equal(t, []byte{0x78}, b.Bytes())
b = new(bytes.Buffer)
encodeLength(b, math.MaxInt64)
require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes())
}
func TestValidUTF8(t *testing.T) {
require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}))
require.False(t, validUTF8([]byte{0xff, 0xff}))
require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74}))
}

View File

@ -0,0 +1,147 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
// Code contains a reason code and reason string for a response.
type Code struct {
Reason string
Code byte
}
// String returns the readable reason for a code.
func (c Code) String() string {
return c.Reason
}
// Error returns the readable reason for a code.
func (c Code) Error() string {
return c.Reason
}
var (
// QosCodes indicicates the reason codes for each Qos byte.
QosCodes = map[byte]Code{
0: CodeGrantedQos0,
1: CodeGrantedQos1,
2: CodeGrantedQos2,
}
CodeSuccess = Code{Code: 0x00, Reason: "success"}
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"}
CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"}
CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"}
CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"}
CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"}
CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"}
CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"}
ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"}
ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"}
ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"}
ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"}
ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"}
ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"}
ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"}
ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"}
ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"}
ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"}
ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"}
ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"}
ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"}
ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"}
ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"}
ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"}
ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"}
ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"}
ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"}
ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"}
ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"}
ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"}
ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"}
ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"}
ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"}
ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"}
ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"}
ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"}
ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"}
ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"}
ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"}
ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"}
ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"}
ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"}
ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"}
ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"}
ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"}
ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"}
ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"}
ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"}
ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"}
ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"}
ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"}
ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"}
ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"}
ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"}
ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"}
ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"}
ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"}
ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"}
ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"}
ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"}
ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"}
ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"}
ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"}
ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"}
ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"}
ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"}
ErrServerBusy = Code{Code: 0x89, Reason: "server busy"}
ErrBanned = Code{Code: 0x8A, Reason: "banned"}
ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"}
ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"}
ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"}
ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"}
ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"}
ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"}
ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"}
ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"}
ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"}
ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"}
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
// MQTTv3 specific bytes.
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
Err3ClientIdentifierNotValid = Code{Code: 0x02}
Err3ServerUnavailable = Code{Code: 0x03}
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
Err3NotAuthorized = Code{Code: 0x05}
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
// This is required because MQTTv3 has different return byte specification.
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
V5CodesToV3 = map[Code]Code{
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
ErrServerUnavailable: Err3ServerUnavailable,
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
ErrBadUsernameOrPassword: Err3NotAuthorized,
}
)

View File

@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestCodesString(t *testing.T) {
c := Code{
Reason: "test",
Code: 0x1,
}
require.Equal(t, "test", c.String())
}
func TestCodesErrorr(t *testing.T) {
c := Code{
Reason: "error",
Code: 0x1,
}
require.Equal(t, "error", error(c).Error())
}

View File

@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
)
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
type FixedHeader struct {
Remaining int `json:"remaining"` // the number of remaining bytes in the payload.
Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
Qos byte `json:"qos"` // indicates the quality of service expected.
Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time.
Retain bool `json:"retain"` // whether the message should be retained.
}
// Encode encodes the FixedHeader and returns a bytes buffer.
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
encodeLength(buf, int64(fh.Remaining))
}
// Decode extracts the specification bits from the header byte.
func (fh *FixedHeader) Decode(hb byte) error {
fh.Type = hb >> 4 // Get the message type from the first 4 bytes.
switch fh.Type {
case Publish:
if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 {
return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4]
}
fh.Dup = (hb>>3)&0x01 > 0 // is duplicate
fh.Qos = (hb >> 1) & 0x03 // qos flag
fh.Retain = hb&0x01 > 0 // is retain flag
case Pubrel:
fallthrough
case Subscribe:
fallthrough
case Unsubscribe:
if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1]
return ErrMalformedFlags
}
fh.Qos = (hb >> 1) & 0x03
default:
if (hb>>0)&0x01 != 0 ||
(hb>>1)&0x01 != 0 ||
(hb>>2)&0x01 != 0 ||
(hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1]
return ErrMalformedFlags
}
}
if fh.Qos == 0 && fh.Dup {
return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2]
}
return nil
}

View File

@ -0,0 +1,237 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
type fixedHeaderTable struct {
desc string
rawBytes []byte
header FixedHeader
packetError bool
expect error
}
var fixedHeaderExpected = []fixedHeaderTable{
{
desc: "connect",
rawBytes: []byte{Connect << 4, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "connack",
rawBytes: []byte{Connack << 4, 0x00},
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "publish",
rawBytes: []byte{Publish << 4, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "publish qos 1",
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "publish qos 1 retain",
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
},
{
desc: "publish qos 2",
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
},
{
desc: "publish qos 2 retain",
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
},
{
desc: "publish dup qos 0",
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
expect: ErrProtocolViolationDupNoQos,
},
{
desc: "publish dup qos 0 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
expect: ErrProtocolViolationDupNoQos,
},
{
desc: "publish dup qos 1 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
},
{
desc: "publish dup qos 2 retain",
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
},
{
desc: "puback",
rawBytes: []byte{Puback << 4, 0x00},
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pubrec",
rawBytes: []byte{Pubrec << 4, 0x00},
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pubrel",
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "pubcomp",
rawBytes: []byte{Pubcomp << 4, 0x00},
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "subscribe",
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "suback",
rawBytes: []byte{Suback << 4, 0x00},
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "unsubscribe",
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
},
{
desc: "unsuback",
rawBytes: []byte{Unsuback << 4, 0x00},
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pingreq",
rawBytes: []byte{Pingreq << 4, 0x00},
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "pingresp",
rawBytes: []byte{Pingresp << 4, 0x00},
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "disconnect",
rawBytes: []byte{Disconnect << 4, 0x00},
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
{
desc: "auth",
rawBytes: []byte{Auth << 4, 0x00},
header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0},
},
// remaining length
{
desc: "remaining length 10",
rawBytes: []byte{Publish << 4, 0x0a},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
},
{
desc: "remaining length 512",
rawBytes: []byte{Publish << 4, 0x80, 0x04},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
},
{
desc: "remaining length 978",
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
},
{
desc: "remaining length 20202",
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
},
{
desc: "remaining length oversize",
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
packetError: true,
},
// Invalid flags for packet
{
desc: "invalid type dup is true",
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid type qos is 1",
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid type retain is true",
rawBytes: []byte{Connect<<4 | 1, 0x00},
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
expect: ErrMalformedFlags,
},
{
desc: "invalid publish qos bits 1 + 2 set",
rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00},
header: FixedHeader{Type: Publish},
expect: ErrProtocolViolationQosOutOfRange,
},
{
desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0",
rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00},
header: FixedHeader{Type: Pubrel, Qos: 1},
expect: ErrMalformedFlags,
},
{
desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0",
rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00},
header: FixedHeader{Type: Subscribe, Qos: 1},
expect: ErrMalformedFlags,
},
}
func TestFixedHeaderEncode(t *testing.T) {
for _, wanted := range fixedHeaderExpected {
t.Run(wanted.desc, func(t *testing.T) {
buf := new(bytes.Buffer)
wanted.header.Encode(buf)
if wanted.expect == nil {
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()))
require.EqualValues(t, wanted.rawBytes, buf.Bytes())
}
})
}
}
func TestFixedHeaderDecode(t *testing.T) {
for _, wanted := range fixedHeaderExpected {
t.Run(wanted.desc, func(t *testing.T) {
fh := new(FixedHeader)
err := fh.Decode(wanted.rawBytes[0])
if wanted.expect != nil {
require.Equal(t, wanted.expect, err)
} else {
require.NoError(t, err)
require.Equal(t, wanted.header.Type, fh.Type)
require.Equal(t, wanted.header.Dup, fh.Dup)
require.Equal(t, wanted.header.Qos, fh.Qos)
require.Equal(t, wanted.header.Retain, fh.Retain)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,505 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"testing"
"github.com/jinzhu/copier"
"github.com/stretchr/testify/require"
)
const pkInfo = "packet type %v, %s"
var packetList = []byte{
Connect,
Connack,
Publish,
Puback,
Pubrec,
Pubrel,
Pubcomp,
Subscribe,
Suback,
Unsubscribe,
Unsuback,
Pingreq,
Pingresp,
Disconnect,
Auth,
}
var pkTable = []TPacketCase{
TPacketData[Connect].Get(TConnectMqtt311),
TPacketData[Connect].Get(TConnectMqtt5),
TPacketData[Connect].Get(TConnectUserPassLWT),
TPacketData[Connack].Get(TConnackAcceptedMqtt5),
TPacketData[Connack].Get(TConnackAcceptedNoSession),
TPacketData[Publish].Get(TPublishBasic),
TPacketData[Publish].Get(TPublishMqtt5),
TPacketData[Puback].Get(TPuback),
TPacketData[Pubrec].Get(TPubrec),
TPacketData[Pubrel].Get(TPubrel),
TPacketData[Pubcomp].Get(TPubcomp),
TPacketData[Subscribe].Get(TSubscribe),
TPacketData[Subscribe].Get(TSubscribeMqtt5),
TPacketData[Suback].Get(TSuback),
TPacketData[Unsubscribe].Get(TUnsubscribe),
TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
TPacketData[Pingreq].Get(TPingreq),
TPacketData[Pingresp].Get(TPingresp),
TPacketData[Disconnect].Get(TDisconnect),
TPacketData[Disconnect].Get(TDisconnectMqtt5),
}
func TestNewPackets(t *testing.T) {
s := NewPackets()
require.NotNil(t, s.internal)
}
func TestPacketsAdd(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{})
require.Contains(t, s.internal, "cl1")
}
func TestPacketsGet(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
pk, ok := s.Get("cl1")
require.True(t, ok)
require.Equal(t, "a1", pk.TopicName)
}
func TestPacketsGetAll(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
s.Add("cl3", Packet{TopicName: "a3"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Contains(t, s.internal, "cl3")
subs := s.GetAll()
require.Len(t, subs, 3)
}
func TestPacketsLen(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
s.Add("cl2", Packet{TopicName: "a2"})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Equal(t, 2, s.Len())
}
func TestSPacketsDelete(t *testing.T) {
s := NewPackets()
s.Add("cl1", Packet{TopicName: "a1"})
require.Contains(t, s.internal, "cl1")
s.Delete("cl1")
_, ok := s.Get("cl1")
require.False(t, ok)
}
func TestFormatPacketID(t *testing.T) {
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
packet := &Packet{PacketID: id}
require.Equal(t, fmt.Sprint(id), packet.FormatID())
}
}
func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
p := &Subscription{
Qos: 2,
NoLocal: true,
RetainAsPublished: true,
RetainHandling: 2,
}
x := new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
p = &Subscription{
Qos: 1,
NoLocal: false,
RetainAsPublished: false,
RetainHandling: 1,
}
x = new(Subscription)
x.decode(p.encode())
require.Equal(t, *p, *x)
}
func TestPacketEncode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !encodeTestOK(wanted) {
return
}
pk := new(Packet)
copier.Copy(pk, wanted.Packet)
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
pk.Mods.AllowResponseInfo = true
buf := new(bytes.Buffer)
var err error
switch pkt {
case Connect:
err = pk.ConnectEncode(buf)
case Connack:
err = pk.ConnackEncode(buf)
case Publish:
err = pk.PublishEncode(buf)
case Puback:
err = pk.PubackEncode(buf)
case Pubrec:
err = pk.PubrecEncode(buf)
case Pubrel:
err = pk.PubrelEncode(buf)
case Pubcomp:
err = pk.PubcompEncode(buf)
case Subscribe:
err = pk.SubscribeEncode(buf)
case Suback:
err = pk.SubackEncode(buf)
case Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case Unsuback:
err = pk.UnsubackEncode(buf)
case Pingreq:
err = pk.PingreqEncode(buf)
case Pingresp:
err = pk.PingrespEncode(buf)
case Disconnect:
err = pk.DisconnectEncode(buf)
case Auth:
err = pk.AuthEncode(buf)
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
encoded := buf.Bytes()
// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
if len(wanted.ActualBytes) > 0 {
wanted.RawBytes = wanted.ActualBytes
}
require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
})
}
}
}
func TestPacketDecode(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if !decodeTestOK(wanted) {
return
}
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
pk.Mods.AllowResponseInfo = true
pk.FixedHeader.Decode(wanted.RawBytes[0])
if len(wanted.RawBytes) > 0 {
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
}
if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
pk.ProtocolVersion = wanted.Packet.ProtocolVersion
}
buf := wanted.RawBytes[2:]
var err error
switch pkt {
case Connect:
err = pk.ConnectDecode(buf)
case Connack:
err = pk.ConnackDecode(buf)
case Publish:
err = pk.PublishDecode(buf)
case Puback:
err = pk.PubackDecode(buf)
case Pubrec:
err = pk.PubrecDecode(buf)
case Pubrel:
err = pk.PubrelDecode(buf)
case Pubcomp:
err = pk.PubcompDecode(buf)
case Subscribe:
err = pk.SubscribeDecode(buf)
case Suback:
err = pk.SubackDecode(buf)
case Unsubscribe:
err = pk.UnsubscribeDecode(buf)
case Unsuback:
err = pk.UnsubackDecode(buf)
case Pingreq:
err = pk.PingreqDecode(buf)
case Pingresp:
err = pk.PingrespDecode(buf)
case Disconnect:
err = pk.DisconnectDecode(buf)
case Auth:
err = pk.AuthDecode(buf)
}
if wanted.FailFirst != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
return
}
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
if pkt == Connect {
// we use ProtocolVersion for controlling packet encoding, but we don't need to test
// against it unless it's a connect packet.
require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
}
require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
})
}
}
}
func TestValidate(t *testing.T) {
for _, pkt := range packetList {
require.Contains(t, TPacketData, pkt)
for _, wanted := range TPacketData[pkt] {
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
if wanted.Group == "validate" || wanted.Primary {
pk := wanted.Packet
var err error
switch pkt {
case Connect:
err = pk.ConnectValidate()
case Publish:
err = pk.PublishValidate(1024)
case Subscribe:
err = pk.SubscribeValidate()
case Unsubscribe:
err = pk.UnsubscribeValidate()
case Auth:
err = pk.AuthValidate()
}
if wanted.Expect != nil {
require.Error(t, err, pkInfo, pkt, wanted.Desc)
require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
}
}
})
}
}
}
func TestAckValidatePubrec(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoMatchingSubscribers.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicNameInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrPayloadFormatInvalid.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubrel(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidatePubcomp(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
ErrPacketIdentifierNotFound.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateSuback(t *testing.T) {
for _, b := range []byte{
CodeGrantedQos0.Code,
CodeGrantedQos1.Code,
CodeGrantedQos2.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
ErrQuotaExceeded.Code,
ErrSharedSubscriptionsNotSupported.Code,
ErrSubscriptionIdentifiersNotSupported.Code,
ErrWildcardSubscriptionsNotSupported.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestAckValidateUnsuback(t *testing.T) {
for _, b := range []byte{
CodeSuccess.Code,
CodeNoSubscriptionExisted.Code,
ErrUnspecifiedError.Code,
ErrImplementationSpecificError.Code,
ErrNotAuthorized.Code,
ErrTopicFilterInvalid.Code,
ErrPacketIdentifierInUse.Code,
} {
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
require.True(t, pk.ReasonCodeValid())
}
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
require.False(t, pk.ReasonCodeValid())
}
func TestReasonCodeValidMisc(t *testing.T) {
pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
require.True(t, pk.ReasonCodeValid())
}
func TestCopy(t *testing.T) {
for _, tt := range pkTable {
pkc := tt.Packet.Copy(true)
require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
pkcc := tt.Packet.Copy(false)
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
}
}
func TestMergeSubscription(t *testing.T) {
sub := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 0,
RetainAsPublished: false,
NoLocal: false,
Identifier: 1,
}
sub2 := Subscription{
Filter: "a/b/d",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 2,
}
expect := Subscription{
Filter: "a/b/c",
RetainHandling: 0,
Qos: 2,
RetainAsPublished: false,
NoLocal: true,
Identifier: 1,
Identifiers: map[string]int{
"a/b/c": 1,
"a/b/d": 2,
},
}
require.Equal(t, expect, sub.Merge(sub2))
}

View File

@ -0,0 +1,479 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"fmt"
"strings"
)
const (
PropPayloadFormat byte = 1
PropMessageExpiryInterval byte = 2
PropContentType byte = 3
PropResponseTopic byte = 8
PropCorrelationData byte = 9
PropSubscriptionIdentifier byte = 11
PropSessionExpiryInterval byte = 17
PropAssignedClientID byte = 18
PropServerKeepAlive byte = 19
PropAuthenticationMethod byte = 21
PropAuthenticationData byte = 22
PropRequestProblemInfo byte = 23
PropWillDelayInterval byte = 24
PropRequestResponseInfo byte = 25
PropResponseInfo byte = 26
PropServerReference byte = 28
PropReasonString byte = 31
PropReceiveMaximum byte = 33
PropTopicAliasMaximum byte = 34
PropTopicAlias byte = 35
PropMaximumQos byte = 36
PropRetainAvailable byte = 37
PropUser byte = 38
PropMaximumPacketSize byte = 39
PropWildcardSubAvailable byte = 40
PropSubIDAvailable byte = 41
PropSharedSubAvailable byte = 42
)
// validPacketProperties indicates which properties are valid for which packet types.
var validPacketProperties = map[byte]map[byte]byte{
PropPayloadFormat: {Publish: 1},
PropMessageExpiryInterval: {Publish: 1},
PropContentType: {Publish: 1},
PropResponseTopic: {Publish: 1},
PropCorrelationData: {Publish: 1},
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
PropAssignedClientID: {Connack: 1},
PropServerKeepAlive: {Connack: 1},
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
PropRequestProblemInfo: {Connect: 1},
PropWillDelayInterval: {Connect: 1},
PropRequestResponseInfo: {Connect: 1},
PropResponseInfo: {Connack: 1},
PropServerReference: {Connack: 1, Disconnect: 1},
PropReasonString: {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropReceiveMaximum: {Connect: 1, Connack: 1},
PropTopicAliasMaximum: {Connect: 1, Connack: 1},
PropTopicAlias: {Publish: 1},
PropMaximumQos: {Connack: 1},
PropRetainAvailable: {Connack: 1},
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
PropMaximumPacketSize: {Connect: 1, Connack: 1},
PropWildcardSubAvailable: {Connack: 1},
PropSubIDAvailable: {Connack: 1},
PropSharedSubAvailable: {Connack: 1},
}
// UserProperty is an arbitrary key-value pair for a packet user properties array.
type UserProperty struct { // [MQTT-1.5.7-1]
Key string `json:"k"`
Val string `json:"v"`
}
// Properties contains all of the mqtt v5 properties available for a packet.
// Some properties have valid values of 0 or not-present. In this case, we opt for
// property flags to indicate the usage of property.
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
type Properties struct {
CorrelationData []byte `json:"cd"`
SubscriptionIdentifier []int `json:"si"`
AuthenticationData []byte `json:"ad"`
User []UserProperty `json:"user"`
ContentType string `json:"ct"`
ResponseTopic string `json:"rt"`
AssignedClientID string `json:"aci"`
AuthenticationMethod string `json:"am"`
ResponseInfo string `json:"ri"`
ServerReference string `json:"sr"`
ReasonString string `json:"rs"`
MessageExpiryInterval uint32 `json:"me"`
SessionExpiryInterval uint32 `json:"sei"`
WillDelayInterval uint32 `json:"wdi"`
MaximumPacketSize uint32 `json:"mps"`
ServerKeepAlive uint16 `json:"ska"`
ReceiveMaximum uint16 `json:"rm"`
TopicAliasMaximum uint16 `json:"tam"`
TopicAlias uint16 `json:"ta"`
PayloadFormat byte `json:"pf"`
PayloadFormatFlag bool `json:"fpf"`
SessionExpiryIntervalFlag bool `json:"fsei"`
ServerKeepAliveFlag bool `json:"fska"`
RequestProblemInfo byte `json:"rpi"`
RequestProblemInfoFlag bool `json:"frpi"`
RequestResponseInfo byte `json:"rri"`
TopicAliasFlag bool `json:"fta"`
MaximumQos byte `json:"mqos"`
MaximumQosFlag bool `json:"fmqos"`
RetainAvailable byte `json:"ra"`
RetainAvailableFlag bool `json:"fra"`
WildcardSubAvailable byte `json:"wsa"`
WildcardSubAvailableFlag bool `json:"fwsa"`
SubIDAvailable byte `json:"sida"`
SubIDAvailableFlag bool `json:"fsida"`
SharedSubAvailable byte `json:"ssa"`
SharedSubAvailableFlag bool `json:"fssa"`
}
// Copy creates a new Properties struct with copies of the values.
func (p *Properties) Copy(allowTransfer bool) Properties {
pr := Properties{
PayloadFormat: p.PayloadFormat, // [MQTT-3.3.2-4]
PayloadFormatFlag: p.PayloadFormatFlag,
MessageExpiryInterval: p.MessageExpiryInterval,
ContentType: p.ContentType, // [MQTT-3.3.2-20]
ResponseTopic: p.ResponseTopic, // [MQTT-3.3.2-15]
SessionExpiryInterval: p.SessionExpiryInterval,
SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag,
AssignedClientID: p.AssignedClientID,
ServerKeepAlive: p.ServerKeepAlive,
ServerKeepAliveFlag: p.ServerKeepAliveFlag,
AuthenticationMethod: p.AuthenticationMethod,
RequestProblemInfo: p.RequestProblemInfo,
RequestProblemInfoFlag: p.RequestProblemInfoFlag,
WillDelayInterval: p.WillDelayInterval,
RequestResponseInfo: p.RequestResponseInfo,
ResponseInfo: p.ResponseInfo,
ServerReference: p.ServerReference,
ReasonString: p.ReasonString,
ReceiveMaximum: p.ReceiveMaximum,
TopicAliasMaximum: p.TopicAliasMaximum,
TopicAlias: 0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27]
MaximumQos: p.MaximumQos,
MaximumQosFlag: p.MaximumQosFlag,
RetainAvailable: p.RetainAvailable,
RetainAvailableFlag: p.RetainAvailableFlag,
MaximumPacketSize: p.MaximumPacketSize,
WildcardSubAvailable: p.WildcardSubAvailable,
WildcardSubAvailableFlag: p.WildcardSubAvailableFlag,
SubIDAvailable: p.SubIDAvailable,
SubIDAvailableFlag: p.SubIDAvailableFlag,
SharedSubAvailable: p.SharedSubAvailable,
SharedSubAvailableFlag: p.SharedSubAvailableFlag,
}
if allowTransfer {
pr.TopicAlias = p.TopicAlias
pr.TopicAliasFlag = p.TopicAliasFlag
}
if len(p.CorrelationData) > 0 {
pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16]
}
if len(p.SubscriptionIdentifier) > 0 {
pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...)
}
if len(p.AuthenticationData) > 0 {
pr.AuthenticationData = append([]byte{}, p.AuthenticationData...)
}
if len(p.User) > 0 {
pr.User = []UserProperty{}
for _, v := range p.User {
pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17]
Key: v.Key,
Val: v.Val,
})
}
}
return pr
}
// canEncode returns true if the property type is valid for the packet type.
func (p *Properties) canEncode(pkt byte, k byte) bool {
return validPacketProperties[k][pkt] == 1
}
// Encode encodes properties into a bytes buffer.
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
if p == nil {
return
}
var buf bytes.Buffer
pkt := pk.FixedHeader.Type
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
buf.WriteByte(PropPayloadFormat)
buf.WriteByte(p.PayloadFormat)
}
if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 {
buf.WriteByte(PropMessageExpiryInterval)
buf.Write(encodeUint32(p.MessageExpiryInterval))
}
if p.canEncode(pkt, PropContentType) && p.ContentType != "" {
buf.WriteByte(PropContentType)
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseTopic)
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropCorrelationData)
buf.Write(encodeBytes(p.CorrelationData))
}
if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 {
for _, v := range p.SubscriptionIdentifier {
if v > 0 {
buf.WriteByte(PropSubscriptionIdentifier)
encodeLength(&buf, int64(v))
}
}
}
if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2]
buf.WriteByte(PropSessionExpiryInterval)
buf.Write(encodeUint32(p.SessionExpiryInterval))
}
if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" {
buf.WriteByte(PropAssignedClientID)
buf.Write(encodeString(p.AssignedClientID))
}
if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag {
buf.WriteByte(PropServerKeepAlive)
buf.Write(encodeUint16(p.ServerKeepAlive))
}
if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" {
buf.WriteByte(PropAuthenticationMethod)
buf.Write(encodeString(p.AuthenticationMethod))
}
if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 {
buf.WriteByte(PropAuthenticationData)
buf.Write(encodeBytes(p.AuthenticationData))
}
if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag {
buf.WriteByte(PropRequestProblemInfo)
buf.WriteByte(p.RequestProblemInfo)
}
if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 {
buf.WriteByte(PropWillDelayInterval)
buf.Write(encodeUint32(p.WillDelayInterval))
}
if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 {
buf.WriteByte(PropRequestResponseInfo)
buf.WriteByte(p.RequestResponseInfo)
}
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
buf.WriteByte(PropResponseInfo)
buf.Write(encodeString(p.ResponseInfo))
}
if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 {
buf.WriteByte(PropServerReference)
buf.Write(encodeString(p.ServerReference))
}
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
b := encodeString(p.ReasonString)
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
buf.WriteByte(PropReasonString)
buf.Write(b)
}
}
if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 {
buf.WriteByte(PropReceiveMaximum)
buf.Write(encodeUint16(p.ReceiveMaximum))
}
if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 {
buf.WriteByte(PropTopicAliasMaximum)
buf.Write(encodeUint16(p.TopicAliasMaximum))
}
if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8]
buf.WriteByte(PropTopicAlias)
buf.Write(encodeUint16(p.TopicAlias))
}
if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 {
buf.WriteByte(PropMaximumQos)
buf.WriteByte(p.MaximumQos)
}
if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag {
buf.WriteByte(PropRetainAvailable)
buf.WriteByte(p.RetainAvailable)
}
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
pb := bytes.NewBuffer([]byte{})
for _, v := range p.User {
pb.WriteByte(PropUser)
pb.Write(encodeString(v.Key))
pb.Write(encodeString(v.Val))
}
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
buf.Write(pb.Bytes())
}
}
if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 {
buf.WriteByte(PropMaximumPacketSize)
buf.Write(encodeUint32(p.MaximumPacketSize))
}
if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag {
buf.WriteByte(PropWildcardSubAvailable)
buf.WriteByte(p.WildcardSubAvailable)
}
if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag {
buf.WriteByte(PropSubIDAvailable)
buf.WriteByte(p.SubIDAvailable)
}
if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag {
buf.WriteByte(PropSharedSubAvailable)
buf.WriteByte(p.SharedSubAvailable)
}
encodeLength(b, int64(buf.Len()))
buf.WriteTo(b) // [MQTT-3.1.3-10]
}
// Decode decodes property bytes into a properties struct.
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
if p == nil {
return 0, nil
}
var bu int
n, bu, err = DecodeLength(b)
if err != nil {
return n + bu, err
}
if n == 0 {
return n + bu, nil
}
bt := b.Bytes()
var k byte
for offset := 0; offset < n; {
k, offset, err = decodeByte(bt, offset)
if err != nil {
return n + bu, err
}
if _, ok := validPacketProperties[k][pk]; !ok {
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
}
switch k {
case PropPayloadFormat:
p.PayloadFormat, offset, err = decodeByte(bt, offset)
p.PayloadFormatFlag = true
case PropMessageExpiryInterval:
p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset)
case PropContentType:
p.ContentType, offset, err = decodeString(bt, offset)
case PropResponseTopic:
p.ResponseTopic, offset, err = decodeString(bt, offset)
case PropCorrelationData:
p.CorrelationData, offset, err = decodeBytes(bt, offset)
case PropSubscriptionIdentifier:
if p.SubscriptionIdentifier == nil {
p.SubscriptionIdentifier = []int{}
}
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
if err != nil {
return n + bu, err
}
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
offset += bu
case PropSessionExpiryInterval:
p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset)
p.SessionExpiryIntervalFlag = true
case PropAssignedClientID:
p.AssignedClientID, offset, err = decodeString(bt, offset)
case PropServerKeepAlive:
p.ServerKeepAlive, offset, err = decodeUint16(bt, offset)
p.ServerKeepAliveFlag = true
case PropAuthenticationMethod:
p.AuthenticationMethod, offset, err = decodeString(bt, offset)
case PropAuthenticationData:
p.AuthenticationData, offset, err = decodeBytes(bt, offset)
case PropRequestProblemInfo:
p.RequestProblemInfo, offset, err = decodeByte(bt, offset)
p.RequestProblemInfoFlag = true
case PropWillDelayInterval:
p.WillDelayInterval, offset, err = decodeUint32(bt, offset)
case PropRequestResponseInfo:
p.RequestResponseInfo, offset, err = decodeByte(bt, offset)
case PropResponseInfo:
p.ResponseInfo, offset, err = decodeString(bt, offset)
case PropServerReference:
p.ServerReference, offset, err = decodeString(bt, offset)
case PropReasonString:
p.ReasonString, offset, err = decodeString(bt, offset)
case PropReceiveMaximum:
p.ReceiveMaximum, offset, err = decodeUint16(bt, offset)
case PropTopicAliasMaximum:
p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset)
case PropTopicAlias:
p.TopicAlias, offset, err = decodeUint16(bt, offset)
p.TopicAliasFlag = true
case PropMaximumQos:
p.MaximumQos, offset, err = decodeByte(bt, offset)
p.MaximumQosFlag = true
case PropRetainAvailable:
p.RetainAvailable, offset, err = decodeByte(bt, offset)
p.RetainAvailableFlag = true
case PropUser:
var k, v string
k, offset, err = decodeString(bt, offset)
if err != nil {
return n + bu, err
}
v, offset, err = decodeString(bt, offset)
p.User = append(p.User, UserProperty{Key: k, Val: v})
case PropMaximumPacketSize:
p.MaximumPacketSize, offset, err = decodeUint32(bt, offset)
case PropWildcardSubAvailable:
p.WildcardSubAvailable, offset, err = decodeByte(bt, offset)
p.WildcardSubAvailableFlag = true
case PropSubIDAvailable:
p.SubIDAvailable, offset, err = decodeByte(bt, offset)
p.SubIDAvailableFlag = true
case PropSharedSubAvailable:
p.SharedSubAvailable, offset, err = decodeByte(bt, offset)
p.SharedSubAvailableFlag = true
}
if err != nil {
return n + bu, err
}
}
return n + bu, nil
}

View File

@ -0,0 +1,333 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
var (
propertiesStruct = Properties{
PayloadFormat: byte(1), // UTF-8 Format
PayloadFormatFlag: true,
MessageExpiryInterval: uint32(2),
ContentType: "text/plain",
ResponseTopic: "a/b/c",
CorrelationData: []byte("data"),
SubscriptionIdentifier: []int{322122},
SessionExpiryInterval: uint32(120),
SessionExpiryIntervalFlag: true,
AssignedClientID: "mochi-v5",
ServerKeepAlive: uint16(20),
ServerKeepAliveFlag: true,
AuthenticationMethod: "SHA-1",
AuthenticationData: []byte("auth-data"),
RequestProblemInfo: byte(1),
RequestProblemInfoFlag: true,
WillDelayInterval: uint32(600),
RequestResponseInfo: byte(1),
ResponseInfo: "response",
ServerReference: "mochi-2",
ReasonString: "reason",
ReceiveMaximum: uint16(500),
TopicAliasMaximum: uint16(999),
TopicAlias: uint16(3),
TopicAliasFlag: true,
MaximumQos: byte(1),
MaximumQosFlag: true,
RetainAvailable: byte(1),
RetainAvailableFlag: true,
User: []UserProperty{
{
Key: "hello",
Val: "世界",
},
{
Key: "key2",
Val: "value2",
},
},
MaximumPacketSize: uint32(32000),
WildcardSubAvailable: byte(1),
WildcardSubAvailableFlag: true,
SubIDAvailable: byte(1),
SubIDAvailableFlag: true,
SharedSubAvailable: byte(1),
SharedSubAvailableFlag: true,
}
propertiesBytes = []byte{
172, 1, // VBI
// Payload Format (1) (vbi:2)
1, 1,
// Message Expiry (2) (vbi:7)
2, 0, 0, 0, 2,
// Content Type (3) (vbi:20)
3,
0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n',
// Response Topic (8) (vbi:28)
8,
0, 5, 'a', '/', 'b', '/', 'c',
// Correlations Data (9) (vbi:35)
9,
0, 4, 'd', 'a', 't', 'a',
// Subscription Identifier (11) (vbi:39)
11,
202, 212, 19,
// Session Expiry Interval (17) (vbi:43)
17,
0, 0, 0, 120,
// Assigned Client ID (18) (vbi:55)
18,
0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5',
// Server Keep Alive (19) (vbi:58)
19,
0, 20,
// Authentication Method (21) (vbi:66)
21,
0, 5, 'S', 'H', 'A', '-', '1',
// Authentication Data (22) (vbi:78)
22,
0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a',
// Request Problem Info (23) (vbi:80)
23, 1,
// Will Delay Interval (24) (vbi:85)
24,
0, 0, 2, 88,
// Request Response Info (25) (vbi:87)
25, 1,
// Response Info (26) (vbi:98)
26,
0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e',
// Server Reference (28) (vbi:108)
28,
0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2',
// Reason String (31) (vbi:117)
31,
0, 6, 'r', 'e', 'a', 's', 'o', 'n',
// Receive Maximum (33) (vbi:120)
33,
1, 244,
// Topic Alias Maximum (34) (vbi:123)
34,
3, 231,
// Topic Alias (35) (vbi:126)
35,
0, 3,
// Maximum Qos (36) (vbi:128)
36, 1,
// Retain Available (37) (vbi: 130)
37, 1,
// User Properties (38) (vbi:161)
38,
0, 5, 'h', 'e', 'l', 'l', 'o',
0, 6, 228, 184, 150, 231, 149, 140,
38,
0, 4, 'k', 'e', 'y', '2',
0, 6, 'v', 'a', 'l', 'u', 'e', '2',
// Maximum Packet Size (39) (vbi:166)
39,
0, 0, 125, 0,
// Wildcard Subscriptions Available (40) (vbi:168)
40, 1,
// Subscription ID Available (41) (vbi:170)
41, 1,
// Shared Subscriptions Available (42) (vbi:172)
42, 1,
}
)
func init() {
validPacketProperties[PropPayloadFormat][Reserved] = 1
validPacketProperties[PropMessageExpiryInterval][Reserved] = 1
validPacketProperties[PropContentType][Reserved] = 1
validPacketProperties[PropResponseTopic][Reserved] = 1
validPacketProperties[PropCorrelationData][Reserved] = 1
validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1
validPacketProperties[PropSessionExpiryInterval][Reserved] = 1
validPacketProperties[PropAssignedClientID][Reserved] = 1
validPacketProperties[PropServerKeepAlive][Reserved] = 1
validPacketProperties[PropAuthenticationMethod][Reserved] = 1
validPacketProperties[PropAuthenticationData][Reserved] = 1
validPacketProperties[PropRequestProblemInfo][Reserved] = 1
validPacketProperties[PropWillDelayInterval][Reserved] = 1
validPacketProperties[PropRequestResponseInfo][Reserved] = 1
validPacketProperties[PropResponseInfo][Reserved] = 1
validPacketProperties[PropServerReference][Reserved] = 1
validPacketProperties[PropReasonString][Reserved] = 1
validPacketProperties[PropReceiveMaximum][Reserved] = 1
validPacketProperties[PropTopicAliasMaximum][Reserved] = 1
validPacketProperties[PropTopicAlias][Reserved] = 1
validPacketProperties[PropMaximumQos][Reserved] = 1
validPacketProperties[PropRetainAvailable][Reserved] = 1
validPacketProperties[PropUser][Reserved] = 1
validPacketProperties[PropMaximumPacketSize][Reserved] = 1
validPacketProperties[PropWildcardSubAvailable][Reserved] = 1
validPacketProperties[PropSubIDAvailable][Reserved] = 1
validPacketProperties[PropSharedSubAvailable][Reserved] = 1
}
func TestEncodeProperties(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
require.Equal(t, propertiesBytes, b.Bytes())
}
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8}))
}
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
props := propertiesStruct
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
require.NotEqual(t, propertiesBytes, b.Bytes())
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
}
func TestEncodePropertiesNil(t *testing.T) {
type tmp struct {
p *Properties
}
pr := tmp{}
b := bytes.NewBuffer([]byte{})
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
require.Equal(t, []byte{}, b.Bytes())
}
func TestEncodeZeroProperties(t *testing.T) {
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
props := new(Properties)
b := bytes.NewBuffer([]byte{})
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
require.Equal(t, []byte{0x00}, b.Bytes())
}
func TestDecodeProperties(t *testing.T) {
b := bytes.NewBuffer(propertiesBytes)
props := new(Properties)
n, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 172+2, n)
require.EqualValues(t, propertiesStruct, *props)
}
func TestDecodePropertiesNil(t *testing.T) {
b := bytes.NewBuffer(propertiesBytes)
type tmp struct {
p *Properties
}
pr := tmp{}
n, err := pr.p.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, 0, n)
}
func TestDecodePropertiesBadInitialVBI(t *testing.T) {
b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, ErrMalformedVariableByteInteger, err)
}
func TestDecodePropertiesZeroLengthVBI(t *testing.T) {
b := bytes.NewBuffer([]byte{0})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.NoError(t, err)
require.Equal(t, props, new(Properties))
}
func TestDecodePropertiesBadKeyByte(t *testing.T) {
b := bytes.NewBuffer([]byte{64, 1})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange)
}
func TestDecodePropertiesInvalidForPacket(t *testing.T) {
b := bytes.NewBuffer([]byte{1, 99})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty)
}
func TestDecodePropertiesGeneralFailure(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestDecodePropertiesBadSubscriptionID(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestDecodePropertiesBadUserProps(t *testing.T) {
b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255})
props := new(Properties)
_, err := props.Decode(Reserved, b)
require.Error(t, err)
}
func TestCopyProperties(t *testing.T) {
require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true))
}
func TestCopyPropertiesNoTransfer(t *testing.T) {
pkA := propertiesStruct
pkB := pkA.Copy(false)
// Properties which should never be transferred from one connection to another
require.Equal(t, uint16(0), pkB.TopicAlias)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package packets
import (
"testing"
"github.com/stretchr/testify/require"
)
func encodeTestOK(wanted TPacketCase) bool {
if wanted.RawBytes == nil {
return false
}
if wanted.Group != "" && wanted.Group != "encode" {
return false
}
return true
}
func decodeTestOK(wanted TPacketCase) bool {
if wanted.Group != "" && wanted.Group != "decode" {
return false
}
return true
}
func TestTPacketCaseGet(t *testing.T) {
require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311))
require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128)))
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,61 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-co
// SPDX-FileContributor: mochi-co
package system
import "sync/atomic"
// Info contains atomic counters and values for various server statistics
// commonly found in $SYS topics (and others).
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
type Info struct {
Version string `json:"version"` // the current version of the server
Started int64 `json:"started"` // the time the server started in unix seconds
Time int64 `json:"time"` // current time on the server
Uptime int64 `json:"uptime"` // the number of seconds the server has been online
BytesReceived int64 `json:"bytes_received"` // total number of bytes received since the broker started
BytesSent int64 `json:"bytes_sent"` // total number of bytes sent since the broker started
ClientsConnected int64 `json:"clients_connected"` // number of currently connected clients
ClientsDisconnected int64 `json:"clients_disconnected"` // total number of persistent clients (with clean session disabled) that are registered at the broker but are currently disconnected
ClientsMaximum int64 `json:"clients_maximum"` // maximum number of active clients that have been connected
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
Retained int64 `json:"retained"` // total number of retained messages active on the broker
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
Subscriptions int64 `json:"subscriptions"` // total number of subscriptions active on the broker
PacketsReceived int64 `json:"packets_received"` // the total number of publish messages received
PacketsSent int64 `json:"packets_sent"` // total number of messages of any type sent since the broker started
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
}
// Clone makes a copy of Info using atomic operation
func (i *Info) Clone() *Info {
return &Info{
Version: i.Version,
Started: atomic.LoadInt64(&i.Started),
Time: atomic.LoadInt64(&i.Time),
Uptime: atomic.LoadInt64(&i.Uptime),
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
BytesSent: atomic.LoadInt64(&i.BytesSent),
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
Retained: atomic.LoadInt64(&i.Retained),
Inflight: atomic.LoadInt64(&i.Inflight),
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
Threads: atomic.LoadInt64(&i.Threads),
}
}

View File

@ -0,0 +1,37 @@
package system
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestClone(t *testing.T) {
o := &Info{
Version: "version",
Started: 1,
Time: 2,
Uptime: 3,
BytesReceived: 4,
BytesSent: 5,
ClientsConnected: 6,
ClientsMaximum: 7,
ClientsTotal: 8,
ClientsDisconnected: 9,
MessagesReceived: 10,
MessagesSent: 11,
MessagesDropped: 20,
Retained: 12,
Inflight: 13,
InflightDropped: 14,
Subscriptions: 15,
PacketsReceived: 16,
PacketsSent: 17,
MemoryAlloc: 18,
Threads: 19,
}
n := o.Clone()
require.Equal(t, o, n)
}

View File

@ -0,0 +1,699 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"strings"
"sync"
"sync/atomic"
"github.com/mochi-co/mqtt/v2/packets"
)
var (
SharePrefix = "$SHARE" // the prefix indicating a share topic
SysPrefix = "$SYS" // the prefix indicating a system info topic
)
// TopicAliases contains inbound and outbound topic alias registrations.
type TopicAliases struct {
Inbound *InboundTopicAliases
Outbound *OutboundTopicAliases
}
// NewTopicAliases returns an instance of TopicAliases.
func NewTopicAliases(topicAliasMaximum uint16) TopicAliases {
return TopicAliases{
Inbound: NewInboundTopicAliases(topicAliasMaximum),
Outbound: NewOutboundTopicAliases(topicAliasMaximum),
}
}
// NewInboundTopicAliases returns a pointer to InboundTopicAliases.
func NewInboundTopicAliases(topicAliasMaximum uint16) *InboundTopicAliases {
return &InboundTopicAliases{
maximum: topicAliasMaximum,
internal: map[uint16]string{},
}
}
// InboundTopicAliases contains a map of topic aliases received from the client.
type InboundTopicAliases struct {
internal map[uint16]string
sync.RWMutex
maximum uint16
}
// Set sets a new alias for a specific topic.
func (a *InboundTopicAliases) Set(id uint16, topic string) string {
a.Lock()
defer a.Unlock()
if a.maximum == 0 {
return topic // ?
}
if existing, ok := a.internal[id]; ok && topic == "" {
return existing
}
a.internal[id] = topic
return topic
}
// OutboundTopicAliases contains a map of topic aliases sent from the broker to the client.
type OutboundTopicAliases struct {
internal map[string]uint16
sync.RWMutex
cursor uint32
maximum uint16
}
// NewOutboundTopicAliases returns a pointer to OutboundTopicAliases.
func NewOutboundTopicAliases(topicAliasMaximum uint16) *OutboundTopicAliases {
return &OutboundTopicAliases{
maximum: topicAliasMaximum,
internal: map[string]uint16{},
}
}
// Set sets a new topic alias for a topic and returns the alias value, and a boolean
// indicating if the alias already existed.
func (a *OutboundTopicAliases) Set(topic string) (uint16, bool) {
a.Lock()
defer a.Unlock()
if a.maximum == 0 {
return 0, false
}
if i, ok := a.internal[topic]; ok {
return i, true
}
i := atomic.LoadUint32(&a.cursor)
if i+1 > uint32(a.maximum) {
// if i+1 > math.MaxUint16 {
return 0, false
}
a.internal[topic] = uint16(i) + 1
atomic.StoreUint32(&a.cursor, i+1)
return uint16(i) + 1, false
}
// SharedSubscriptions contains a map of subscriptions to a shared filter,
// keyed on share group then client id.
type SharedSubscriptions struct {
internal map[string]map[string]packets.Subscription
sync.RWMutex
}
// NewSharedSubscriptions returns a new instance of Subscriptions.
func NewSharedSubscriptions() *SharedSubscriptions {
return &SharedSubscriptions{
internal: map[string]map[string]packets.Subscription{},
}
}
// Add creates a new shared subscription for a group and client id pair.
func (s *SharedSubscriptions) Add(group, id string, val packets.Subscription) {
s.Lock()
defer s.Unlock()
if _, ok := s.internal[group]; !ok {
s.internal[group] = map[string]packets.Subscription{}
}
s.internal[group][id] = val
}
// Delete deletes a client id from a shared subscription group.
func (s *SharedSubscriptions) Delete(group, id string) {
s.Lock()
defer s.Unlock()
delete(s.internal[group], id)
if len(s.internal[group]) == 0 {
delete(s.internal, group)
}
}
// Get returns the subscription properties for a client id in a share group, if one exists.
func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, ok bool) {
s.RLock()
defer s.RUnlock()
if _, ok := s.internal[group]; !ok {
return val, ok
}
val, ok = s.internal[group][id]
return val, ok
}
// GroupLen returns the number of groups subscribed to the filter.
func (s *SharedSubscriptions) GroupLen() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Len returns the total number of shared subscriptions to a filter across all groups.
func (s *SharedSubscriptions) Len() int {
s.RLock()
defer s.RUnlock()
n := 0
for _, group := range s.internal {
n += len(group)
}
return n
}
// GetAll returns all shared subscription groups and their subscriptions.
func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscription {
s.RLock()
defer s.RUnlock()
m := map[string]map[string]packets.Subscription{}
for group, subs := range s.internal {
if _, ok := m[group]; !ok {
m[group] = map[string]packets.Subscription{}
}
for id, sub := range subs {
m[group][id] = sub
}
}
return m
}
// Subscriptions is a map of subscriptions keyed on client.
type Subscriptions struct {
internal map[string]packets.Subscription
sync.RWMutex
}
// NewSubscriptions returns a new instance of Subscriptions.
func NewSubscriptions() *Subscriptions {
return &Subscriptions{
internal: map[string]packets.Subscription{},
}
}
// Add adds a new subscription for a client. ID can be a filter in the
// case this map is client state, or a client id if particle state.
func (s *Subscriptions) Add(id string, val packets.Subscription) {
s.Lock()
defer s.Unlock()
s.internal[id] = val
}
// GetAll returns all subscriptions.
func (s *Subscriptions) GetAll() map[string]packets.Subscription {
s.RLock()
defer s.RUnlock()
m := map[string]packets.Subscription{}
for k, v := range s.internal {
m[k] = v
}
return m
}
// Get returns a subscriptions for a specific client or filter id.
func (s *Subscriptions) Get(id string) (val packets.Subscription, ok bool) {
s.RLock()
defer s.RUnlock()
val, ok = s.internal[id]
return val, ok
}
// Len returns the number of subscriptions.
func (s *Subscriptions) Len() int {
s.RLock()
defer s.RUnlock()
val := len(s.internal)
return val
}
// Delete removes a subscription by client or filter id.
func (s *Subscriptions) Delete(id string) {
s.Lock()
defer s.Unlock()
delete(s.internal, id)
}
// ClientSubscriptions is a map of aggregated subscriptions for a client.
type ClientSubscriptions map[string]packets.Subscription
// Subscribers contains the shared and non-shared subscribers matching a topic.
type Subscribers struct {
Shared map[string]map[string]packets.Subscription
SharedSelected map[string]packets.Subscription
Subscriptions map[string]packets.Subscription
}
// SelectShared returns one subscriber for each shared subscription group.
func (s *Subscribers) SelectShared() {
s.SharedSelected = map[string]packets.Subscription{}
for _, subs := range s.Shared {
for client, sub := range subs {
cls, ok := s.SharedSelected[client]
if !ok {
cls = sub
}
s.SharedSelected[client] = cls.Merge(sub)
break
}
}
}
// MergeSharedSelected merges the selected subscribers for a shared subscription group
// and the non-shared subscribers, to ensure that no subscriber gets multiple messages
// due to have both types of subscription matching the same filter.
func (s *Subscribers) MergeSharedSelected() {
for client, sub := range s.SharedSelected {
cls, ok := s.Subscriptions[client]
if !ok {
cls = sub
}
s.Subscriptions[client] = cls.Merge(sub)
}
}
// TopicsIndex is a prefix/trie tree containing topic subscribers and retained messages.
type TopicsIndex struct {
Retained *packets.Packets
root *particle // a leaf containing a message and more leaves.
}
// NewTopicsIndex returns a pointer to a new instance of Index.
func NewTopicsIndex() *TopicsIndex {
return &TopicsIndex{
Retained: packets.NewPackets(),
root: &particle{
particles: newParticles(),
subscriptions: NewSubscriptions(),
},
}
}
// Subscribe adds a new subscription for a client to a topic filter, returning
// true if the subscription was new.
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
x.root.Lock()
defer x.root.Unlock()
var existed bool
prefix, _ := isolateParticle(subscription.Filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
group, _ := isolateParticle(subscription.Filter, 1)
n := x.set(subscription.Filter, 2)
_, existed = n.shared.Get(group, client)
n.shared.Add(group, client, subscription)
} else {
n := x.set(subscription.Filter, 0)
_, existed = n.subscriptions.Get(client)
n.subscriptions.Add(client, subscription)
}
return !existed
}
// Unsubscribe removes a subscription filter for a client, returning true if the
// subscription existed.
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
x.root.Lock()
defer x.root.Unlock()
var d int
if strings.HasPrefix(filter, SharePrefix) {
d = 2
}
particle := x.seek(filter, d)
if particle == nil {
return false
}
prefix, _ := isolateParticle(filter, 0)
if strings.EqualFold(prefix, SharePrefix) {
group, _ := isolateParticle(filter, 1)
particle.shared.Delete(group, client)
} else {
particle.subscriptions.Delete(client)
}
x.trim(particle)
return true
}
// RetainMessage saves a message payload to the end of a topic address. Returns
// 1 if a retained message was added, and -1 if the retained message was removed.
// 0 is returned if sequential empty payloads are received.
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
x.root.Lock()
defer x.root.Unlock()
n := x.set(pk.TopicName, 0)
n.Lock()
defer n.Unlock()
if len(pk.Payload) > 0 {
n.retainPath = pk.TopicName
x.Retained.Add(pk.TopicName, pk)
return 1
}
var out int64
if pke, ok := x.Retained.Get(pk.TopicName); ok && len(pke.Payload) > 0 && pke.FixedHeader.Retain {
out = -1 // if a retained packet existed, return -1
}
n.retainPath = ""
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
x.trim(n)
return out
}
// set creates a topic address in the index and returns the final particle.
func (x *TopicsIndex) set(topic string, d int) *particle {
var key string
var hasNext = true
n := x.root
for hasNext {
key, hasNext = isolateParticle(topic, d)
d++
p := n.particles.get(key)
if p == nil {
p = newParticle(key, n)
n.particles.add(p)
}
n = p
}
return n
}
// seek finds the particle at a specific index in a topic filter.
func (x *TopicsIndex) seek(filter string, d int) *particle {
var key string
var hasNext = true
n := x.root
for hasNext {
key, hasNext = isolateParticle(filter, d)
n = n.particles.get(key)
d++
if n == nil {
return nil
}
}
return n
}
// trim removes empty filter particles from the index.
func (x *TopicsIndex) trim(n *particle) {
for n.parent != nil && n.retainPath == "" && n.particles.len()+n.subscriptions.Len()+n.shared.Len() == 0 {
key := n.key
n = n.parent
n.particles.delete(key)
}
}
// Messages returns a slice of any retained messages which match a filter.
func (x *TopicsIndex) Messages(filter string) []packets.Packet {
return x.scanMessages(filter, 0, nil, []packets.Packet{})
}
// scanMessages returns all retained messages on topics matching a given filter.
func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []packets.Packet) []packets.Packet {
if n == nil {
n = x.root
}
if len(filter) == 0 || x.Retained.Len() == 0 {
return pks
}
if !strings.ContainsRune(filter, '#') && !strings.ContainsRune(filter, '+') {
if pk, ok := x.Retained.Get(filter); ok {
pks = append(pks, pk)
}
return pks
}
key, hasNext := isolateParticle(filter, d)
if key == "+" || key == "#" || d == -1 {
for _, adjacent := range n.particles.getAll() {
if d == 0 && adjacent.key == SysPrefix {
continue
}
if !hasNext {
if adjacent.retainPath != "" {
if pk, ok := x.Retained.Get(adjacent.retainPath); ok {
pks = append(pks, pk)
}
}
}
if hasNext || (d >= 0 && key == "#") {
pks = x.scanMessages(filter, d+1, adjacent, pks)
}
}
return pks
}
if particle := n.particles.get(key); particle != nil {
if hasNext {
return x.scanMessages(filter, d+1, particle, pks)
}
if pk, ok := x.Retained.Get(particle.retainPath); ok {
pks = append(pks, pk)
}
}
return pks
}
// Subscribers returns a map of clients who are subscribed to matching filters,
// their subscription ids and highest qos.
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
return x.scanSubscribers(topic, 0, nil, &Subscribers{
Shared: map[string]map[string]packets.Subscription{},
SharedSelected: map[string]packets.Subscription{},
Subscriptions: map[string]packets.Subscription{},
})
}
// scanSubscribers returns a list of client subscriptions matching an indexed topic address.
func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Subscribers) *Subscribers {
if n == nil {
n = x.root
}
if len(topic) == 0 {
return subs
}
key, hasNext := isolateParticle(topic, d)
for _, partKey := range []string{key, "+", "#"} {
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
x.gatherSubscriptions(topic, particle, subs)
x.gatherSharedSubscriptions(particle, subs)
if wild := particle.particles.get("#"); wild != nil && partKey != "#" && partKey != "+" {
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
}
if hasNext {
x.scanSubscribers(topic, d+1, particle, subs)
}
}
}
return subs
}
// gatherSubscriptions collects any matching subscriptions, and gathers any identifiers or highest qos values.
func (x *TopicsIndex) gatherSubscriptions(topic string, particle *particle, subs *Subscribers) {
if subs.Subscriptions == nil {
subs.Subscriptions = map[string]packets.Subscription{}
}
for client, sub := range particle.subscriptions.GetAll() {
if len(sub.Filter) > 0 && topic[0] == '$' && (sub.Filter[0] == '+' || sub.Filter[0] == '#') { // don't match $ topics with top level wildcards [MQTT-4.7.1-1] [MQTT-4.7.1-2]
continue
}
cls, ok := subs.Subscriptions[client]
if !ok {
cls = sub
}
subs.Subscriptions[client] = cls.Merge(sub)
}
}
// gatherSharedSubscriptions gathers all shared subscriptions for a particle.
func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscribers) {
if subs.Shared == nil {
subs.Shared = map[string]map[string]packets.Subscription{}
}
for _, shares := range particle.shared.GetAll() {
for client, sub := range shares {
if _, ok := subs.Shared[sub.Filter]; !ok {
subs.Shared[sub.Filter] = map[string]packets.Subscription{}
}
subs.Shared[sub.Filter][client] = sub
}
}
}
// isolateParticle extracts a particle between d / and d+1 / without allocations.
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
var next, end int
for i := 0; end > -1 && i <= d; i++ {
end = strings.IndexRune(filter, '/')
switch {
case d > -1 && i == d && end > -1:
hasNext = true
particle = filter[next:end]
case end > -1:
hasNext = false
filter = filter[end+1:]
default:
hasNext = false
particle = filter[next:]
}
}
return
}
// IsSharedFilter returns true if the filter uses the share prefix.
func IsSharedFilter(filter string) bool {
prefix, _ := isolateParticle(filter, 0)
return strings.EqualFold(prefix, SharePrefix)
}
// IsValidFilter returns true if the filter is valid.
func IsValidFilter(filter string, forPublish bool) bool {
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs.
return false // [MQTT-4.7.3-1]
}
if forPublish {
if len(filter) >= len(SysPrefix) && strings.EqualFold(filter[0:len(SysPrefix)], SysPrefix) {
// 4.7.2 Non-normative - The Server SHOULD prevent Clients from using such Topic Names [$SYS] to exchange messages with other Clients.
return false
}
if strings.ContainsRune(filter, '+') || strings.ContainsRune(filter, '#') {
return false //[MQTT-3.3.2-2]
}
}
wildhash := strings.IndexRune(filter, '#')
if wildhash >= 0 && wildhash != len(filter)-1 { // [MQTT-4.7.1-2]
return false
}
prefix, hasNext := isolateParticle(filter, 0)
if !hasNext && strings.EqualFold(prefix, SharePrefix) {
return false // [MQTT-4.8.2-1]
}
if hasNext && strings.EqualFold(prefix, SharePrefix) {
group, hasNext := isolateParticle(filter, 1)
if !hasNext {
return false // [MQTT-4.8.2-1]
}
if strings.ContainsRune(group, '+') || strings.ContainsRune(group, '#') {
return false // [MQTT-4.8.2-2]
}
}
return true
}
// particle is a child node on the tree.
type particle struct {
key string // the key of the particle
parent *particle // a pointer to the parent of the particle
particles particles // a map of child particles
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
retainPath string // path of a retained message
sync.Mutex // mutex for when making changes to the particle
}
// newParticle returns a pointer to a new instance of particle.
func newParticle(key string, parent *particle) *particle {
return &particle{
key: key,
parent: parent,
particles: newParticles(),
subscriptions: NewSubscriptions(),
shared: NewSharedSubscriptions(),
}
}
// particles is a concurrency safe map of particles.
type particles struct {
internal map[string]*particle
sync.RWMutex
}
// newParticles returns a map of particles.
func newParticles() particles {
return particles{
internal: map[string]*particle{},
}
}
// add adds a new particle.
func (p *particles) add(val *particle) {
p.Lock()
p.internal[val.key] = val
p.Unlock()
}
// getAll returns all particles.
func (p *particles) getAll() map[string]*particle {
p.RLock()
defer p.RUnlock()
m := map[string]*particle{}
for k, v := range p.internal {
m[k] = v
}
return m
}
// get returns a particle by id (key).
func (p *particles) get(id string) *particle {
p.RLock()
defer p.RUnlock()
return p.internal[id]
}
// len returns the number of particles.
func (p *particles) len() int {
p.RLock()
defer p.RUnlock()
val := len(p.internal)
return val
}
// delete removes a particle.
func (p *particles) delete(id string) {
p.Lock()
defer p.Unlock()
delete(p.internal, id)
}

View File

@ -0,0 +1,842 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"testing"
"github.com/mochi-co/mqtt/v2/packets"
"github.com/stretchr/testify/require"
)
const (
testGroup = "testgroup"
otherGroup = "other"
)
func TestNewSharedSubscriptions(t *testing.T) {
s := NewSharedSubscriptions()
require.NotNil(t, s.internal)
}
func TestSharedSubscriptionsAdd(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Filter: "a/b/c"})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
}
func TestSharedSubscriptionsGet(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Qos: 2})
s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl2")
sub, ok := s.Get(testGroup, "cl2")
require.Equal(t, true, ok)
require.Equal(t, byte(2), sub.Qos)
}
func TestSharedSubscriptionsGetAll(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
s.Add(otherGroup, "cl3", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl2")
require.Contains(t, s.internal, otherGroup)
require.Contains(t, s.internal[otherGroup], "cl3")
subs := s.GetAll()
require.Len(t, subs, 2)
require.Len(t, subs[testGroup], 2)
require.Len(t, subs[otherGroup], 1)
}
func TestSharedSubscriptionsLen(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
s.Add(otherGroup, "cl2", packets.Subscription{Qos: 1})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl2")
require.Contains(t, s.internal, otherGroup)
require.Contains(t, s.internal[otherGroup], "cl2")
require.Equal(t, 3, s.Len())
require.Equal(t, 2, s.GroupLen())
}
func TestSharedSubscriptionsDelete(t *testing.T) {
s := NewSharedSubscriptions()
s.Add(testGroup, "cl1", packets.Subscription{Qos: 1})
s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl1")
require.Contains(t, s.internal, testGroup)
require.Contains(t, s.internal[testGroup], "cl2")
require.Equal(t, 2, s.Len())
s.Delete(testGroup, "cl1")
_, ok := s.Get(testGroup, "cl1")
require.False(t, ok)
require.Equal(t, 1, s.GroupLen())
require.Equal(t, 1, s.Len())
s.Delete(testGroup, "cl2")
_, ok = s.Get(testGroup, "cl2")
require.False(t, ok)
require.Equal(t, 0, s.GroupLen())
require.Equal(t, 0, s.Len())
}
func TestNewSubscriptions(t *testing.T) {
s := NewSubscriptions()
require.NotNil(t, s.internal)
}
func TestSubscriptionsAdd(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{})
require.Contains(t, s.internal, "cl1")
}
func TestSubscriptionsGet(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{Qos: 2})
s.Add("cl2", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
sub, ok := s.Get("cl1")
require.True(t, ok)
require.Equal(t, byte(2), sub.Qos)
}
func TestSubscriptionsGetAll(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{Qos: 0})
s.Add("cl2", packets.Subscription{Qos: 1})
s.Add("cl3", packets.Subscription{Qos: 2})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Contains(t, s.internal, "cl3")
subs := s.GetAll()
require.Len(t, subs, 3)
}
func TestSubscriptionsLen(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{Qos: 0})
s.Add("cl2", packets.Subscription{Qos: 1})
require.Contains(t, s.internal, "cl1")
require.Contains(t, s.internal, "cl2")
require.Equal(t, 2, s.Len())
}
func TestSubscriptionsDelete(t *testing.T) {
s := NewSubscriptions()
s.Add("cl1", packets.Subscription{Qos: 1})
require.Contains(t, s.internal, "cl1")
s.Delete("cl1")
_, ok := s.Get("cl1")
require.False(t, ok)
}
func TestNewTopicsIndex(t *testing.T) {
index := NewTopicsIndex()
require.NotNil(t, index)
require.NotNil(t, index.root)
}
func BenchmarkNewTopicsIndex(b *testing.B) {
for n := 0; n < b.N; n++ {
NewTopicsIndex()
}
}
func TestSubscribe(t *testing.T) {
tt := []struct {
desc string
client string
filter string
subscription packets.Subscription
wasNew bool
}{
{
desc: "subscribe",
client: "cl1",
subscription: packets.Subscription{Filter: "a/b/c", Qos: 2},
wasNew: true,
},
{
desc: "subscribe existed",
client: "cl1",
subscription: packets.Subscription{Filter: "a/b/c", Qos: 1},
wasNew: false,
},
{
desc: "subscribe case sensitive didnt exist",
client: "cl1",
subscription: packets.Subscription{Filter: "A/B/c", Qos: 1},
wasNew: true,
},
{
desc: "wildcard+ sub",
client: "cl1",
subscription: packets.Subscription{Filter: "d/+"},
wasNew: true,
},
{
desc: "wildcard# sub",
client: "cl1",
subscription: packets.Subscription{Filter: "d/e/#"},
wasNew: true,
},
}
index := NewTopicsIndex()
for _, tx := range tt {
t.Run(tx.desc, func(t *testing.T) {
require.Equal(t, tx.wasNew, index.Subscribe(tx.client, tx.subscription))
})
}
final := index.root.particles.get("a").particles.get("b").particles.get("c")
require.NotNil(t, final)
client, exists := final.subscriptions.Get("cl1")
require.True(t, exists)
require.Equal(t, byte(1), client.Qos)
}
func TestSubscribeShared(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c", Qos: 2})
final := index.root.particles.get("a").particles.get("b").particles.get("c")
require.NotNil(t, final)
client, exists := final.shared.Get("tmp", "cl1")
require.True(t, exists)
require.Equal(t, byte(2), client.Qos)
require.Equal(t, 0, final.subscriptions.Len())
require.Equal(t, 1, final.shared.Len())
}
func BenchmarkSubscribe(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
index.Subscribe("client-1", packets.Subscription{Filter: "a/b/c"})
}
}
func BenchmarkSubscribeShared(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
index.Subscribe("client-1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c"})
}
}
func TestUnsubscribe(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/d", Qos: 1})
client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/+/d", Qos: 1})
client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
index.Subscribe("cl1", packets.Subscription{Filter: "d/e/f", Qos: 1})
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
index.Subscribe("cl2", packets.Subscription{Filter: "d/e/f", Qos: 1})
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
require.NotNil(t, client)
require.True(t, exists)
index.Subscribe("cl3", packets.Subscription{Filter: "#", Qos: 2})
client, exists = index.root.particles.get("#").subscriptions.Get("cl3")
require.NotNil(t, client)
require.True(t, exists)
ok := index.Unsubscribe("a/b/c/d", "cl1")
require.True(t, ok)
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
ok = index.Unsubscribe("d/e/f", "cl1")
require.True(t, ok)
require.Equal(t, 1, index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Len())
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
require.NotNil(t, client)
require.True(t, exists)
ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "nobody")
require.False(t, ok)
}
func TestUnsubscribeNoCascade(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/e/e"})
ok := index.Unsubscribe("a/b/c/e/e", "cl1")
require.True(t, ok)
require.Equal(t, 1, index.root.particles.len())
client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").subscriptions.Get("cl1")
require.NotNil(t, client)
require.True(t, exists)
}
func TestUnsubscribeShared(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c", Qos: 2})
final := index.root.particles.get("a").particles.get("b").particles.get("c")
require.NotNil(t, final)
client, exists := final.shared.Get("tmp", "cl1")
require.True(t, exists)
require.Equal(t, byte(2), client.Qos)
require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1"))
_, exists = final.shared.Get("tmp", "cl1")
require.False(t, exists)
}
func BenchmarkUnsubscribe(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
b.StopTimer()
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
b.StartTimer()
index.Unsubscribe("a/b/c", "cl1")
}
}
func TestIndexSeek(t *testing.T) {
filter := "a/b/c/d/e/f"
index := NewTopicsIndex()
k1 := index.set(filter, 0)
require.Equal(t, "f", k1.key)
k1.subscriptions.Add("cl1", packets.Subscription{})
require.Equal(t, k1, index.seek(filter, 0))
require.Nil(t, index.seek("d/e/f", 0))
}
func TestIndexTrim(t *testing.T) {
index := NewTopicsIndex()
k1 := index.set("a/b/c", 0)
require.Equal(t, "c", k1.key)
k1.subscriptions.Add("cl1", packets.Subscription{})
k2 := index.set("a/b/c/d/e/f", 0)
require.Equal(t, "f", k2.key)
k2.subscriptions.Add("cl1", packets.Subscription{})
k3 := index.set("a/b", 0)
require.Equal(t, "b", k3.key)
k3.subscriptions.Add("cl1", packets.Subscription{})
index.trim(k2)
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").particles.get("e").particles.get("f"))
require.NotNil(t, index.root.particles.get("a").particles.get("b"))
k2.subscriptions.Delete("cl1")
index.trim(k2)
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d"))
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
k1.subscriptions.Delete("cl1")
k3.subscriptions.Delete("cl1")
index.trim(k2)
require.Nil(t, index.root.particles.get("a"))
}
func TestIndexSet(t *testing.T) {
index := NewTopicsIndex()
child := index.set("a/b/c", 0)
require.Equal(t, "c", child.key)
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
child = index.set("a/b/c/d/e", 0)
require.Equal(t, "e", child.key)
child = index.set("a/b/c/c/a", 0)
require.Equal(t, "a", child.key)
}
func TestIndexSetPrefixed(t *testing.T) {
index := NewTopicsIndex()
child := index.set("/c", 0)
require.Equal(t, "c", child.key)
require.NotNil(t, index.root.particles.get("").particles.get("c"))
}
func BenchmarkIndexSet(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
index.set("a/b/c", 0)
}
}
func TestRetainMessage(t *testing.T) {
pk := packets.Packet{
FixedHeader: packets.FixedHeader{Retain: true},
TopicName: "a/b/c",
Payload: []byte("hello"),
}
index := NewTopicsIndex()
r := index.RetainMessage(pk)
require.Equal(t, int64(1), r)
pke, ok := index.Retained.Get(pk.TopicName)
require.True(t, ok)
require.Equal(t, pk, pke)
pk2 := packets.Packet{
FixedHeader: packets.FixedHeader{Retain: true},
TopicName: "a/b/d/f",
Payload: []byte("hello"),
}
r = index.RetainMessage(pk2)
require.Equal(t, int64(1), r)
// The same message already exists, but we're not doing a deep-copy check, so it's considered to be a new message.
r = index.RetainMessage(pk2)
require.Equal(t, int64(1), r)
// Clear existing retained
pk3 := packets.Packet{TopicName: "a/b/c", Payload: []byte{}}
r = index.RetainMessage(pk3)
require.Equal(t, int64(-1), r)
_, ok = index.Retained.Get(pk.TopicName)
require.False(t, ok)
// Clear no retained
r = index.RetainMessage(pk3)
require.Equal(t, int64(0), r)
}
func BenchmarkRetainMessage(b *testing.B) {
index := NewTopicsIndex()
for n := 0; n < b.N; n++ {
index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
}
}
func TestIsolateParticle(t *testing.T) {
particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
require.Equal(t, "to", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
require.Equal(t, "my", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
require.Equal(t, "mqtt", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("/path/", 0)
require.Equal(t, "", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 1)
require.Equal(t, "path", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("/path/", 2)
require.Equal(t, "", particle)
require.Equal(t, false, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 3)
require.Equal(t, "+", particle)
require.Equal(t, true, hasNext)
particle, hasNext = isolateParticle("a/b/c/+/+", 4)
require.Equal(t, "+", particle)
require.Equal(t, false, hasNext)
}
func BenchmarkIsolateParticle(b *testing.B) {
for n := 0; n < b.N; n++ {
isolateParticle("path/to/my/mqtt", 3)
}
}
func TestScanSubscribers(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c", Identifier: 22})
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c/d/e/f"})
index.Subscribe("cl1", packets.Subscription{Qos: 2, Filter: "a/b/c/d/+/f"})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/#"})
index.Subscribe("cl2", packets.Subscription{Qos: 1, Filter: "a/b/c"})
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "a/b/+", Identifier: 77})
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "d/e/f", Identifier: 7237})
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "$SYS/uptime", Identifier: 3})
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: "+/b", Identifier: 234})
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: "#", Identifier: 5})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 4, len(subs.Subscriptions))
require.Contains(t, subs.Subscriptions, "cl1")
require.Contains(t, subs.Subscriptions, "cl2")
require.Contains(t, subs.Subscriptions, "cl3")
require.Contains(t, subs.Subscriptions, "cl4")
require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos)
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"])
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
subs = index.scanSubscribers("", 0, nil, new(Subscribers))
require.Equal(t, 0, len(subs.Subscriptions))
}
func TestScanSubscribersShared(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 3, len(subs.Shared))
}
func TestSelectSharedSubscriber(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110})
index.Subscribe("cl1b", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
require.Equal(t, 2, len(subs.Shared))
require.Contains(t, subs.Shared, SharePrefix+"/tmp/a/b/c")
require.Contains(t, subs.Shared, SharePrefix+"/tmp2/a/b/c")
require.Len(t, subs.Shared[SharePrefix+"/tmp/a/b/c"], 3)
require.Len(t, subs.Shared[SharePrefix+"/tmp2/a/b/c"], 1)
subs.SelectShared()
require.Len(t, subs.SharedSelected, 2)
}
func TestMergeSharedSelected(t *testing.T) {
s := &Subscribers{
SharedSelected: map[string]packets.Subscription{
"cl1": {Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110},
"cl2": {Qos: 1, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 111},
},
Subscriptions: map[string]packets.Subscription{
"cl2": {Qos: 1, Filter: "a/b/c", Identifier: 112},
},
}
s.MergeSharedSelected()
require.Equal(t, 2, len(s.Subscriptions))
require.Contains(t, s.Subscriptions, "cl1")
require.Contains(t, s.Subscriptions, "cl2")
require.EqualValues(t, map[string]int{
SharePrefix + "/tmp2/a/b/c": 111,
"a/b/c": 112,
}, s.Subscriptions["cl2"].Identifiers)
}
func TestSubscribersFind(t *testing.T) {
tt := []struct {
filter string
topic string
matched bool
}{
{filter: "a", topic: "a", matched: true},
{filter: "a/", topic: "a", matched: false},
{filter: "a/", topic: "a/", matched: true},
{filter: "/a", topic: "/a", matched: true},
{filter: "path/to/my/mqtt", topic: "path/to/my/mqtt", matched: true},
{filter: "path/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
{filter: "+/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
{filter: "#", topic: "path/to/my/mqtt", matched: true},
{filter: "+/+/+/+", topic: "path/to/my/mqtt", matched: true},
{filter: "+/+/+/#", topic: "path/to/my/mqtt", matched: true},
{filter: "zen/#", topic: "zen", matched: true}, // as per 4.7.1.2
{filter: "trailing-end/#", topic: "trailing-end/", matched: true},
{filter: "+/prefixed", topic: "/prefixed", matched: true},
{filter: "+/+/#", topic: "path/to/my/mqtt", matched: true},
{filter: "path/to/", topic: "path/to/my/mqtt", matched: false},
{filter: "#/stuff", topic: "path/to/my/mqtt", matched: false},
{filter: "#", topic: "$SYS/info", matched: false},
{filter: "$SYS/#", topic: "$SYS/info", matched: true},
{filter: "+/info", topic: "$SYS/info", matched: false},
}
for _, tx := range tt {
t.Run("filter:'"+tx.filter+"' vs topic:'"+tx.topic+"'", func(t *testing.T) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: tx.filter})
subs := index.Subscribers(tx.topic)
require.Equal(t, tx.matched, len(subs.Subscriptions) == 1)
})
}
}
func BenchmarkSubscribers(b *testing.B) {
index := NewTopicsIndex()
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
index.Subscribe("cl1", packets.Subscription{Filter: "a/+/c"})
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/+"})
index.Subscribe("cl2", packets.Subscription{Filter: "a/b/c/d"})
index.Subscribe("cl3", packets.Subscription{Filter: "#"})
for n := 0; n < b.N; n++ {
index.Subscribers("a/b/c")
}
}
func TestMessagesPattern(t *testing.T) {
payload := []byte("hello")
fh := packets.FixedHeader{Type: packets.Publish, Retain: true}
pks := []packets.Packet{
{TopicName: "$SYS/uptime", Payload: payload, FixedHeader: fh},
{TopicName: "$SYS/info", Payload: payload, FixedHeader: fh},
{TopicName: "a/b/c/d", Payload: payload, FixedHeader: fh},
{TopicName: "a/b/c/e", Payload: payload, FixedHeader: fh},
{TopicName: "a/b/d/f", Payload: payload, FixedHeader: fh},
{TopicName: "q/w/e/r/t/y", Payload: payload, FixedHeader: fh},
{TopicName: "q/x/e/r/t/o", Payload: payload, FixedHeader: fh},
{TopicName: "asdf", Payload: payload, FixedHeader: fh},
}
tt := []struct {
filter string
len int
}{
{"a/b/c/d", 1},
{"$SYS/+", 2},
{"$SYS/#", 2},
{"#", len(pks) - 2},
{"a/b/c/+", 2},
{"a/+/c/+", 2},
{"+/+/+/d", 1},
{"q/w/e/#", 1},
{"+/+/+/+", 3},
{"q/#", 2},
{"asdf", 1},
{"", 0},
{"#", 6},
}
index := NewTopicsIndex()
for _, pk := range pks {
index.RetainMessage(pk)
}
for _, tx := range tt {
t.Run("filter:'"+tx.filter, func(t *testing.T) {
messages := index.Messages(tx.filter)
require.Equal(t, tx.len, len(messages))
})
}
}
func BenchmarkMessages(b *testing.B) {
index := NewTopicsIndex()
index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
index.RetainMessage(packets.Packet{TopicName: "a/b/d/e/f"})
index.RetainMessage(packets.Packet{TopicName: "d/e/f/g"})
index.RetainMessage(packets.Packet{TopicName: "$SYS/info"})
index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
for n := 0; n < b.N; n++ {
index.Messages("+/b/c/+")
}
}
func TestNewParticles(t *testing.T) {
cl := newParticles()
require.NotNil(t, cl.internal)
}
func TestParticlesAdd(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
require.Contains(t, p.internal, "a")
}
func TestParticlesGet(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
p.add(&particle{key: "b"})
require.Contains(t, p.internal, "a")
require.Contains(t, p.internal, "b")
particle := p.get("a")
require.NotNil(t, particle)
require.Equal(t, "a", particle.key)
}
func TestParticlesGetAll(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
p.add(&particle{key: "b"})
p.add(&particle{key: "c"})
require.Contains(t, p.internal, "a")
require.Contains(t, p.internal, "b")
require.Contains(t, p.internal, "c")
particles := p.getAll()
require.Len(t, particles, 3)
}
func TestParticlesLen(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
p.add(&particle{key: "b"})
require.Contains(t, p.internal, "a")
require.Contains(t, p.internal, "b")
require.Equal(t, 2, p.len())
}
func TestParticlesDelete(t *testing.T) {
p := newParticles()
p.add(&particle{key: "a"})
require.Contains(t, p.internal, "a")
p.delete("a")
particle := p.get("a")
require.Nil(t, particle)
}
func TestIsValid(t *testing.T) {
require.True(t, IsValidFilter("a/b/c", false))
require.True(t, IsValidFilter("a/b//c", false))
require.True(t, IsValidFilter("$SYS", false))
require.True(t, IsValidFilter("$SYS/info", false))
require.True(t, IsValidFilter("$sys/info", false))
require.True(t, IsValidFilter("abc/#", false))
require.False(t, IsValidFilter("", false))
require.False(t, IsValidFilter(SharePrefix, false))
require.False(t, IsValidFilter(SharePrefix+"/", false))
require.False(t, IsValidFilter(SharePrefix+"/b+/", false))
require.False(t, IsValidFilter(SharePrefix+"/+", false))
require.False(t, IsValidFilter(SharePrefix+"/#", false))
require.False(t, IsValidFilter(SharePrefix+"/#/", false))
require.False(t, IsValidFilter("a/#/c", false))
}
func TestIsValidForPublish(t *testing.T) {
require.True(t, IsValidFilter("", true))
require.True(t, IsValidFilter("a/b/c", true))
require.False(t, IsValidFilter("a/b/+/d", true))
require.False(t, IsValidFilter("a/b/#", true))
require.False(t, IsValidFilter("$SYS/info", true))
}
func TestIsSharedFilter(t *testing.T) {
require.True(t, IsSharedFilter(SharePrefix+"/tmp/a/b/c"))
require.False(t, IsSharedFilter("a/b/c"))
}
func TestNewInboundAliases(t *testing.T) {
a := NewInboundTopicAliases(5)
require.NotNil(t, a)
require.NotNil(t, a.internal)
require.Equal(t, uint16(5), a.maximum)
}
func TestInboundAliasesSet(t *testing.T) {
topic := "test"
id := uint16(1)
a := NewInboundTopicAliases(5)
require.Equal(t, topic, a.Set(id, topic))
require.Contains(t, a.internal, id)
require.Equal(t, a.internal[id], topic)
require.Equal(t, topic, a.Set(id, ""))
}
func TestInboundAliasesSetMaxZero(t *testing.T) {
topic := "test"
id := uint16(1)
a := NewInboundTopicAliases(0)
require.Equal(t, topic, a.Set(id, topic))
require.NotContains(t, a.internal, id)
}
func TestNewOutboundAliases(t *testing.T) {
a := NewOutboundTopicAliases(5)
require.NotNil(t, a)
require.NotNil(t, a.internal)
require.Equal(t, uint16(5), a.maximum)
require.Equal(t, uint32(0), a.cursor)
}
func TestOutboundAliasesSet(t *testing.T) {
a := NewOutboundTopicAliases(3)
n, ok := a.Set("t1")
require.False(t, ok)
require.Equal(t, uint16(1), n)
n, ok = a.Set("t2")
require.False(t, ok)
require.Equal(t, uint16(2), n)
n, ok = a.Set("t3")
require.False(t, ok)
require.Equal(t, uint16(3), n)
n, ok = a.Set("t4")
require.False(t, ok)
require.Equal(t, uint16(0), n)
n, ok = a.Set("t2")
require.True(t, ok)
require.Equal(t, uint16(2), n)
}
func TestOutboundAliasesSetMaxZero(t *testing.T) {
topic := "test"
a := NewOutboundTopicAliases(0)
n, ok := a.Set(topic)
require.False(t, ok)
require.Equal(t, uint16(0), n)
}
func TestNewTopicAliases(t *testing.T) {
a := NewTopicAliases(5)
require.NotNil(t, a.Inbound)
require.Equal(t, uint16(5), a.Inbound.maximum)
require.NotNil(t, a.Outbound)
require.Equal(t, uint16(5), a.Outbound.maximum)
}

View File

@ -0,0 +1 @@
language: go

View File

@ -0,0 +1,35 @@
bbloom.go
// The MIT License (MIT)
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
siphash.go
// https://github.com/dchest/siphash
//
// Written in 2012 by Dmitry Chestnykh.
//
// To the extent possible under law, the author have dedicated all copyright
// and related and neighboring rights to this software to the public domain
// worldwide. This software is distributed without any warranty.
// http://creativecommons.org/publicdomain/zero/1.0/
//
// Package siphash implements SipHash-2-4, a fast short-input PRF
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.

View File

@ -0,0 +1,131 @@
## bbloom: a bitset Bloom filter for go/golang
===
[![Build Status](https://travis-ci.org/AndreasBriese/bbloom.png?branch=master)](http://travis-ci.org/AndreasBriese/bbloom)
package implements a fast bloom filter with real 'bitset' and JSONMarshal/JSONUnmarshal to store/reload the Bloom filter.
NOTE: the package uses unsafe.Pointer to set and read the bits from the bitset. If you're uncomfortable with using the unsafe package, please consider using my bloom filter package at github.com/AndreasBriese/bloom
===
changelog 11/2015: new thread safe methods AddTS(), HasTS(), AddIfNotHasTS() following a suggestion from Srdjan Marinovic (github @a-little-srdjan), who used this to code a bloomfilter cache.
This bloom filter was developed to strengthen a website-log database and was tested and optimized for this log-entry mask: "2014/%02i/%02i %02i:%02i:%02i /info.html".
Nonetheless bbloom should work with any other form of entries.
~~Hash function is a modified Berkeley DB sdbm hash (to optimize for smaller strings). sdbm http://www.cse.yorku.ca/~oz/hash.html~~
Found sipHash (SipHash-2-4, a fast short-input PRF created by Jean-Philippe Aumasson and Daniel J. Bernstein.) to be about as fast. sipHash had been ported by Dimtry Chestnyk to Go (github.com/dchest/siphash )
Minimum hashset size is: 512 ([4]uint64; will be set automatically).
###install
```sh
go get github.com/AndreasBriese/bbloom
```
###test
+ change to folder ../bbloom
+ create wordlist in file "words.txt" (you might use `python permut.py`)
+ run 'go test -bench=.' within the folder
```go
go test -bench=.
```
~~If you've installed the GOCONVEY TDD-framework http://goconvey.co/ you can run the tests automatically.~~
using go's testing framework now (have in mind that the op timing is related to 65536 operations of Add, Has, AddIfNotHas respectively)
### usage
after installation add
```go
import (
...
"github.com/AndreasBriese/bbloom"
...
)
```
at your header. In the program use
```go
// create a bloom filter for 65536 items and 1 % wrong-positive ratio
bf := bbloom.New(float64(1<<16), float64(0.01))
// or
// create a bloom filter with 650000 for 65536 items and 7 locs per hash explicitly
// bf = bbloom.New(float64(650000), float64(7))
// or
bf = bbloom.New(650000.0, 7.0)
// add one item
bf.Add([]byte("butter"))
// Number of elements added is exposed now
// Note: ElemNum will not be included in JSON export (for compatability to older version)
nOfElementsInFilter := bf.ElemNum
// check if item is in the filter
isIn := bf.Has([]byte("butter")) // should be true
isNotIn := bf.Has([]byte("Butter")) // should be false
// 'add only if item is new' to the bloomfilter
added := bf.AddIfNotHas([]byte("butter")) // should be false because 'butter' is already in the set
added = bf.AddIfNotHas([]byte("buTTer")) // should be true because 'buTTer' is new
// thread safe versions for concurrent use: AddTS, HasTS, AddIfNotHasTS
// add one item
bf.AddTS([]byte("peanutbutter"))
// check if item is in the filter
isIn = bf.HasTS([]byte("peanutbutter")) // should be true
isNotIn = bf.HasTS([]byte("peanutButter")) // should be false
// 'add only if item is new' to the bloomfilter
added = bf.AddIfNotHasTS([]byte("butter")) // should be false because 'peanutbutter' is already in the set
added = bf.AddIfNotHasTS([]byte("peanutbuTTer")) // should be true because 'penutbuTTer' is new
// convert to JSON ([]byte)
Json := bf.JSONMarshal()
// bloomfilters Mutex is exposed for external un-/locking
// i.e. mutex lock while doing JSON conversion
bf.Mtx.Lock()
Json = bf.JSONMarshal()
bf.Mtx.Unlock()
// restore a bloom filter from storage
bfNew := bbloom.JSONUnmarshal(Json)
isInNew := bfNew.Has([]byte("butter")) // should be true
isNotInNew := bfNew.Has([]byte("Butter")) // should be false
```
to work with the bloom filter.
### why 'fast'?
It's about 3 times faster than William Fitzgeralds bitset bloom filter https://github.com/willf/bloom . And it is about so fast as my []bool set variant for Boom filters (see https://github.com/AndreasBriese/bloom ) but having a 8times smaller memory footprint:
Bloom filter (filter size 524288, 7 hashlocs)
github.com/AndreasBriese/bbloom 'Add' 65536 items (10 repetitions): 6595800 ns (100 ns/op)
github.com/AndreasBriese/bbloom 'Has' 65536 items (10 repetitions): 5986600 ns (91 ns/op)
github.com/AndreasBriese/bloom 'Add' 65536 items (10 repetitions): 6304684 ns (96 ns/op)
github.com/AndreasBriese/bloom 'Has' 65536 items (10 repetitions): 6568663 ns (100 ns/op)
github.com/willf/bloom 'Add' 65536 items (10 repetitions): 24367224 ns (371 ns/op)
github.com/willf/bloom 'Test' 65536 items (10 repetitions): 21881142 ns (333 ns/op)
github.com/dataence/bloom/standard 'Add' 65536 items (10 repetitions): 23041644 ns (351 ns/op)
github.com/dataence/bloom/standard 'Check' 65536 items (10 repetitions): 19153133 ns (292 ns/op)
github.com/cabello/bloom 'Add' 65536 items (10 repetitions): 131921507 ns (2012 ns/op)
github.com/cabello/bloom 'Contains' 65536 items (10 repetitions): 131108962 ns (2000 ns/op)
(on MBPro15 OSX10.8.5 i7 4Core 2.4Ghz)
With 32bit bloom filters (bloom32) using modified sdbm, bloom32 does hashing with only 2 bit shifts, one xor and one substraction per byte. smdb is about as fast as fnv64a but gives less collisions with the dataset (see mask above). bloom.New(float64(10 * 1<<16),float64(7)) populated with 1<<16 random items from the dataset (see above) and tested against the rest results in less than 0.05% collisions.

View File

@ -0,0 +1,284 @@
// The MIT License (MIT)
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// 2019/08/25 code revision to reduce unsafe use
// Parts are adopted from the fork at ipfs/bbloom after performance rev by
// Steve Allen (https://github.com/Stebalien)
// (see https://github.com/ipfs/bbloom/blob/master/bbloom.go)
// -> func Has
// -> func set
// -> func add
package bbloom
import (
"bytes"
"encoding/json"
"log"
"math"
"sync"
"unsafe"
)
// helper
// not needed anymore by Set
// var mask = []uint8{1, 2, 4, 8, 16, 32, 64, 128}
func getSize(ui64 uint64) (size uint64, exponent uint64) {
if ui64 < uint64(512) {
ui64 = uint64(512)
}
size = uint64(1)
for size < ui64 {
size <<= 1
exponent++
}
return size, exponent
}
func calcSizeByWrongPositives(numEntries, wrongs float64) (uint64, uint64) {
size := -1 * numEntries * math.Log(wrongs) / math.Pow(float64(0.69314718056), 2)
locs := math.Ceil(float64(0.69314718056) * size / numEntries)
return uint64(size), uint64(locs)
}
// New
// returns a new bloomfilter
func New(params ...float64) (bloomfilter Bloom) {
var entries, locs uint64
if len(params) == 2 {
if params[1] < 1 {
entries, locs = calcSizeByWrongPositives(params[0], params[1])
} else {
entries, locs = uint64(params[0]), uint64(params[1])
}
} else {
log.Fatal("usage: New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(3)) or New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(0.03))")
}
size, exponent := getSize(uint64(entries))
bloomfilter = Bloom{
Mtx: &sync.Mutex{},
sizeExp: exponent,
size: size - 1,
setLocs: locs,
shift: 64 - exponent,
}
bloomfilter.Size(size)
return bloomfilter
}
// NewWithBoolset
// takes a []byte slice and number of locs per entry
// returns the bloomfilter with a bitset populated according to the input []byte
func NewWithBoolset(bs *[]byte, locs uint64) (bloomfilter Bloom) {
bloomfilter = New(float64(len(*bs)<<3), float64(locs))
for i, b := range *bs {
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bloomfilter.bitset[0])) + uintptr(i))) = b
}
return bloomfilter
}
// bloomJSONImExport
// Im/Export structure used by JSONMarshal / JSONUnmarshal
type bloomJSONImExport struct {
FilterSet []byte
SetLocs uint64
}
// JSONUnmarshal
// takes JSON-Object (type bloomJSONImExport) as []bytes
// returns Bloom object
func JSONUnmarshal(dbData []byte) Bloom {
bloomImEx := bloomJSONImExport{}
json.Unmarshal(dbData, &bloomImEx)
buf := bytes.NewBuffer(bloomImEx.FilterSet)
bs := buf.Bytes()
bf := NewWithBoolset(&bs, bloomImEx.SetLocs)
return bf
}
//
// Bloom filter
type Bloom struct {
Mtx *sync.Mutex
ElemNum uint64
bitset []uint64
sizeExp uint64
size uint64
setLocs uint64
shift uint64
}
// <--- http://www.cse.yorku.ca/~oz/hash.html
// modified Berkeley DB Hash (32bit)
// hash is casted to l, h = 16bit fragments
// func (bl Bloom) absdbm(b *[]byte) (l, h uint64) {
// hash := uint64(len(*b))
// for _, c := range *b {
// hash = uint64(c) + (hash << 6) + (hash << bl.sizeExp) - hash
// }
// h = hash >> bl.shift
// l = hash << bl.shift >> bl.shift
// return l, h
// }
// Update: found sipHash of Jean-Philippe Aumasson & Daniel J. Bernstein to be even faster than absdbm()
// https://131002.net/siphash/
// siphash was implemented for Go by Dmitry Chestnykh https://github.com/dchest/siphash
// Add
// set the bit(s) for entry; Adds an entry to the Bloom filter
func (bl *Bloom) Add(entry []byte) {
l, h := bl.sipHash(entry)
for i := uint64(0); i < bl.setLocs; i++ {
bl.set((h + i*l) & bl.size)
bl.ElemNum++
}
}
// AddTS
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
func (bl *Bloom) AddTS(entry []byte) {
bl.Mtx.Lock()
defer bl.Mtx.Unlock()
bl.Add(entry)
}
// Has
// check if bit(s) for entry is/are set
// returns true if the entry was added to the Bloom Filter
func (bl Bloom) Has(entry []byte) bool {
l, h := bl.sipHash(entry)
res := true
for i := uint64(0); i < bl.setLocs; i++ {
res = res && bl.isSet((h+i*l)&bl.size)
// https://github.com/ipfs/bbloom/commit/84e8303a9bfb37b2658b85982921d15bbb0fecff
// // Branching here (early escape) is not worth it
// // This is my conclusion from benchmarks
// // (prevents loop unrolling)
// switch bl.IsSet((h + i*l) & bl.size) {
// case false:
// return false
// }
}
return res
}
// HasTS
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
func (bl *Bloom) HasTS(entry []byte) bool {
bl.Mtx.Lock()
defer bl.Mtx.Unlock()
return bl.Has(entry)
}
// AddIfNotHas
// Only Add entry if it's not present in the bloomfilter
// returns true if entry was added
// returns false if entry was allready registered in the bloomfilter
func (bl Bloom) AddIfNotHas(entry []byte) (added bool) {
if bl.Has(entry) {
return added
}
bl.Add(entry)
return true
}
// AddIfNotHasTS
// Tread safe: Only Add entry if it's not present in the bloomfilter
// returns true if entry was added
// returns false if entry was allready registered in the bloomfilter
func (bl *Bloom) AddIfNotHasTS(entry []byte) (added bool) {
bl.Mtx.Lock()
defer bl.Mtx.Unlock()
return bl.AddIfNotHas(entry)
}
// Size
// make Bloom filter with as bitset of size sz
func (bl *Bloom) Size(sz uint64) {
bl.bitset = make([]uint64, sz>>6)
}
// Clear
// resets the Bloom filter
func (bl *Bloom) Clear() {
bs := bl.bitset
for i := range bs {
bs[i] = 0
}
}
// Set
// set the bit[idx] of bitsit
func (bl *Bloom) set(idx uint64) {
// ommit unsafe
// *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3))) |= mask[idx%8]
bl.bitset[idx>>6] |= 1 << (idx % 64)
}
// IsSet
// check if bit[idx] of bitset is set
// returns true/false
func (bl *Bloom) isSet(idx uint64) bool {
// ommit unsafe
// return (((*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)))) >> (idx % 8)) & 1) == 1
return bl.bitset[idx>>6]&(1<<(idx%64)) != 0
}
// JSONMarshal
// returns JSON-object (type bloomJSONImExport) as []byte
func (bl Bloom) JSONMarshal() []byte {
bloomImEx := bloomJSONImExport{}
bloomImEx.SetLocs = uint64(bl.setLocs)
bloomImEx.FilterSet = make([]byte, len(bl.bitset)<<3)
for i := range bloomImEx.FilterSet {
bloomImEx.FilterSet[i] = *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[0])) + uintptr(i)))
}
data, err := json.Marshal(bloomImEx)
if err != nil {
log.Fatal("json.Marshal failed: ", err)
}
return data
}
// // alternative hashFn
// func (bl Bloom) fnv64a(b *[]byte) (l, h uint64) {
// h64 := fnv.New64a()
// h64.Write(*b)
// hash := h64.Sum64()
// h = hash >> 32
// l = hash << 32 >> 32
// return l, h
// }
//
// // <-- http://partow.net/programming/hashfunctions/index.html
// // citation: An algorithm proposed by Donald E. Knuth in The Art Of Computer Programming Volume 3,
// // under the topic of sorting and search chapter 6.4.
// // modified to fit with boolset-length
// func (bl Bloom) DEKHash(b *[]byte) (l, h uint64) {
// hash := uint64(len(*b))
// for _, c := range *b {
// hash = ((hash << 5) ^ (hash >> bl.shift)) ^ uint64(c)
// }
// h = hash >> bl.shift
// l = hash << bl.sizeExp >> bl.sizeExp
// return l, h
// }

View File

@ -0,0 +1,225 @@
// Written in 2012 by Dmitry Chestnykh.
//
// To the extent possible under law, the author have dedicated all copyright
// and related and neighboring rights to this software to the public domain
// worldwide. This software is distributed without any warranty.
// http://creativecommons.org/publicdomain/zero/1.0/
//
// Package siphash implements SipHash-2-4, a fast short-input PRF
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
package bbloom
// Hash returns the 64-bit SipHash-2-4 of the given byte slice with two 64-bit
// parts of 128-bit key: k0 and k1.
func (bl Bloom) sipHash(p []byte) (l, h uint64) {
// Initialization.
v0 := uint64(8317987320269560794) // k0 ^ 0x736f6d6570736575
v1 := uint64(7237128889637516672) // k1 ^ 0x646f72616e646f6d
v2 := uint64(7816392314733513934) // k0 ^ 0x6c7967656e657261
v3 := uint64(8387220255325274014) // k1 ^ 0x7465646279746573
t := uint64(len(p)) << 56
// Compression.
for len(p) >= 8 {
m := uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 |
uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
v3 ^= m
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
v0 ^= m
p = p[8:]
}
// Compress last block.
switch len(p) {
case 7:
t |= uint64(p[6]) << 48
fallthrough
case 6:
t |= uint64(p[5]) << 40
fallthrough
case 5:
t |= uint64(p[4]) << 32
fallthrough
case 4:
t |= uint64(p[3]) << 24
fallthrough
case 3:
t |= uint64(p[2]) << 16
fallthrough
case 2:
t |= uint64(p[1]) << 8
fallthrough
case 1:
t |= uint64(p[0])
}
v3 ^= t
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
v0 ^= t
// Finalization.
v2 ^= 0xff
// Round 1.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 2.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 3.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// Round 4.
v0 += v1
v1 = v1<<13 | v1>>51
v1 ^= v0
v0 = v0<<32 | v0>>32
v2 += v3
v3 = v3<<16 | v3>>48
v3 ^= v2
v0 += v3
v3 = v3<<21 | v3>>43
v3 ^= v0
v2 += v1
v1 = v1<<17 | v1>>47
v1 ^= v2
v2 = v2<<32 | v2>>32
// return v0 ^ v1 ^ v2 ^ v3
hash := v0 ^ v1 ^ v2 ^ v3
h = hash >> bl.shift
l = hash << bl.shift >> bl.shift
return l, h
}

View File

@ -0,0 +1,24 @@
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <http://unlicense.org/>

View File

@ -0,0 +1,7 @@
# gopher-json [![GoDoc](https://godoc.org/layeh.com/gopher-json?status.svg)](https://godoc.org/layeh.com/gopher-json)
Package json is a simple JSON encoder/decoder for [gopher-lua](https://github.com/yuin/gopher-lua).
## License
Public domain

View File

@ -0,0 +1,33 @@
// Package json is a simple JSON encoder/decoder for gopher-lua.
//
// Documentation
//
// The following functions are exposed by the library:
// decode(string): Decodes a JSON string. Returns nil and an error string if
// the string could not be decoded.
// encode(value): Encodes a value into a JSON string. Returns nil and an error
// string if the value could not be encoded.
//
// The following types are supported:
//
// Lua | JSON
// ---------+-----
// nil | null
// number | number
// string | string
// table | object: when table is non-empty and has only string keys
// | array: when table is empty, or has only sequential numeric keys
// | starting from 1
//
// Attempting to encode any other Lua type will result in an error.
//
// Example
//
// Below is an example usage of the library:
// import (
// luajson "layeh.com/gopher-json"
// )
//
// L := lua.NewState()
// luajson.Preload(s)
package json

View File

@ -0,0 +1,189 @@
package json
import (
"encoding/json"
"errors"
"github.com/yuin/gopher-lua"
)
// Preload adds json to the given Lua state's package.preload table. After it
// has been preloaded, it can be loaded using require:
//
// local json = require("json")
func Preload(L *lua.LState) {
L.PreloadModule("json", Loader)
}
// Loader is the module loader function.
func Loader(L *lua.LState) int {
t := L.NewTable()
L.SetFuncs(t, api)
L.Push(t)
return 1
}
var api = map[string]lua.LGFunction{
"decode": apiDecode,
"encode": apiEncode,
}
func apiDecode(L *lua.LState) int {
if L.GetTop() != 1 {
L.Error(lua.LString("bad argument #1 to decode"), 1)
return 0
}
str := L.CheckString(1)
value, err := Decode(L, []byte(str))
if err != nil {
L.Push(lua.LNil)
L.Push(lua.LString(err.Error()))
return 2
}
L.Push(value)
return 1
}
func apiEncode(L *lua.LState) int {
if L.GetTop() != 1 {
L.Error(lua.LString("bad argument #1 to encode"), 1)
return 0
}
value := L.CheckAny(1)
data, err := Encode(value)
if err != nil {
L.Push(lua.LNil)
L.Push(lua.LString(err.Error()))
return 2
}
L.Push(lua.LString(string(data)))
return 1
}
var (
errNested = errors.New("cannot encode recursively nested tables to JSON")
errSparseArray = errors.New("cannot encode sparse array")
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
)
type invalidTypeError lua.LValueType
func (i invalidTypeError) Error() string {
return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
}
// Encode returns the JSON encoding of value.
func Encode(value lua.LValue) ([]byte, error) {
return json.Marshal(jsonValue{
LValue: value,
visited: make(map[*lua.LTable]bool),
})
}
type jsonValue struct {
lua.LValue
visited map[*lua.LTable]bool
}
func (j jsonValue) MarshalJSON() (data []byte, err error) {
switch converted := j.LValue.(type) {
case lua.LBool:
data, err = json.Marshal(bool(converted))
case lua.LNumber:
data, err = json.Marshal(float64(converted))
case *lua.LNilType:
data = []byte(`null`)
case lua.LString:
data, err = json.Marshal(string(converted))
case *lua.LTable:
if j.visited[converted] {
return nil, errNested
}
j.visited[converted] = true
key, value := converted.Next(lua.LNil)
switch key.Type() {
case lua.LTNil: // empty table
data = []byte(`[]`)
case lua.LTNumber:
arr := make([]jsonValue, 0, converted.Len())
expectedKey := lua.LNumber(1)
for key != lua.LNil {
if key.Type() != lua.LTNumber {
err = errInvalidKeys
return
}
if expectedKey != key {
err = errSparseArray
return
}
arr = append(arr, jsonValue{value, j.visited})
expectedKey++
key, value = converted.Next(key)
}
data, err = json.Marshal(arr)
case lua.LTString:
obj := make(map[string]jsonValue)
for key != lua.LNil {
if key.Type() != lua.LTString {
err = errInvalidKeys
return
}
obj[key.String()] = jsonValue{value, j.visited}
key, value = converted.Next(key)
}
data, err = json.Marshal(obj)
default:
err = errInvalidKeys
}
default:
err = invalidTypeError(j.LValue.Type())
}
return
}
// Decode converts the JSON encoded data to Lua values.
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
var value interface{}
err := json.Unmarshal(data, &value)
if err != nil {
return nil, err
}
return DecodeValue(L, value), nil
}
// DecodeValue converts the value to a Lua value.
//
// This function only converts values that the encoding/json package decodes to.
// All other values will return lua.LNil.
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
switch converted := value.(type) {
case bool:
return lua.LBool(converted)
case float64:
return lua.LNumber(converted)
case string:
return lua.LString(converted)
case json.Number:
return lua.LString(converted)
case []interface{}:
arr := L.CreateTable(len(converted), 0)
for _, item := range converted {
arr.Append(DecodeValue(L, item))
}
return arr
case map[string]interface{}:
tbl := L.CreateTable(0, len(converted))
for key, item := range converted {
tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
}
return tbl
case nil:
return lua.LNil
}
return lua.LNil
}

View File

@ -0,0 +1,6 @@
/integration/redis_src/
/integration/dump.rdb
*.swp
/integration/nodes.conf
.idea/
miniredis.iml

View File

@ -0,0 +1,225 @@
## Changelog
### v2.23.0
- basic INFO support (thanks @kirill-a-belov)
- support COUNT in SSCAN (thanks @Abdi-dd)
- test and support Go 1.19
- support LPOS (thanks @ianstarz)
- support XPENDING, XGROUP {CREATECONSUMER,DESTROY,DELCONSUMER}, XINFO {CONSUMERS,GROUPS}, XCLAIM (thanks @sandyharvie)
### v2.22.0
- set miniredis.DumpMaxLineLen to get more Dump() info (thanks @afjoseph)
- fix invalid resposne of COMMAND (thanks @zsh1995)
- fix possibility to generate duplicate IDs in XADD (thanks @readams)
- adds support for XAUTOCLAIM min-idle parameter (thanks @readams)
### v2.21.0
- support for GETEX (thanks @dntj)
- support for GT and LT in ZADD (thanks @lsgndln)
- support for XAUTOCLAIM (thanks @randall-fulton)
### v2.20.0
- back to support Go >= 1.14 (thanks @ajatprabha and @marcind)
### v2.19.0
- support for TYPE in SCAN (thanks @0xDiddi)
- update BITPOS (thanks @dirkm)
- fix a lua redis.call() return value (thanks @mpetronic)
- update ZRANGE (thanks @valdemarpereira)
### v2.18.0
- support for ZUNION (thanks @propan)
- support for COPY (thanks @matiasinsaurralde and @rockitbaby)
- support for LMOVE (thanks @btwear)
### v2.17.0
- added miniredis.RunT(t)
### v2.16.1
- fix ZINTERSTORE with wets (thanks @lingjl2010 and @okhowang)
- fix exclusive ranges in XRANGE (thanks @joseotoro)
### v2.16.0
- simplify some code (thanks @zonque)
- support for EXAT/PXAT in SET
- support for XTRIM (thanks @joseotoro)
- support for ZRANDMEMBER
- support for redis.log() in lua (thanks @dirkm)
### v2.15.2
- Fix race condition in blocking code (thanks @zonque and @robx)
- XREAD accepts '$' as ID (thanks @bradengroom)
### v2.15.1
- EVAL should cache the script (thanks @guoshimin)
### v2.15.0
- target redis 6.2 and added new args to various commands
- support for all hyperlog commands (thanks @ilbaktin)
- support for GETDEL (thanks @wszaranski)
### v2.14.5
- added XPENDING
- support for BLOCK option in XREAD and XREADGROUP
### v2.14.4
- fix BITPOS error (thanks @xiaoyuzdy)
- small fixes for XREAD, XACK, and XDEL. Mostly error cases.
- fix empty EXEC return type (thanks @ashanbrown)
- fix XDEL (thanks @svakili and @yvesf)
- fix FLUSHALL for streams (thanks @svakili)
### v2.14.3
- fix problem where Lua code didn't set the selected DB
- update to redis 6.0.10 (thanks @lazappa)
### v2.14.2
- update LUA dependency
- deal with (p)unsubscribe when there are no channels
### v2.14.1
- mod tidy
### v2.14.0
- support for HELLO and the RESP3 protocol
- KEEPTTL in SET (thanks @johnpena)
### v2.13.3
- support Go 1.14 and 1.15
- update the `Check...()` methods
- support for XREAD (thanks @pieterlexis)
### v2.13.2
- Use SAN instead of CN in self signed cert for testing (thanks @johejo)
- Travis CI now tests against the most recent two versions of Go (thanks @johejo)
- changed unit and integration tests to compare raw payloads, not parsed payloads
- remove "redigo" dependency
### v2.13.1
- added HSTRLEN
- minimal support for ACL users in AUTH
### v2.13.0
- added RunTLS(...)
- added SetError(...)
### v2.12.0
- redis 6
- Lua json update (thanks @gsmith85)
- CLUSTER commands (thanks @kratisto)
- fix TOUCH
- fix a shutdown race condition
### v2.11.4
- ZUNIONSTORE now supports standard set types (thanks @wshirey)
### v2.11.3
- support for TOUCH (thanks @cleroux)
- support for cluster and stream commands (thanks @kak-tus)
### v2.11.2
- make sure Lua code is executed concurrently
- add command GEORADIUSBYMEMBER (thanks @kyeett)
### v2.11.1
- globals protection for Lua code (thanks @vk-outreach)
- HSET update (thanks @carlgreen)
- fix BLPOP block on shutdown (thanks @Asalle)
### v2.11.0
- added XRANGE/XREVRANGE, XADD, and XLEN (thanks @skateinmars)
- added GEODIST
- improved precision for geohashes, closer to what real redis does
- use 128bit floats internally for INCRBYFLOAT and related (thanks @timnd)
### v2.10.1
- added m.Server()
### v2.10.0
- added UNLINK
- fix DEL zero-argument case
- cleanup some direct access commands
- added GEOADD, GEOPOS, GEORADIUS, and GEORADIUS_RO
### v2.9.1
- fix issue with ZRANGEBYLEX
- fix issue with BRPOPLPUSH and direct access
### v2.9.0
- proper versioned import of github.com/gomodule/redigo (thanks @yfei1)
- fix messages generated by PSUBSCRIBE
- optional internal seed (thanks @zikaeroh)
### v2.8.0
Proper `v2` in go.mod.
### older
See https://github.com/alicebob/miniredis/releases for the full changelog

Some files were not shown because too many files have changed in this diff Show More