feat(mqtt): change proj structure to services and add mochi broker
This commit is contained in:
parent
26c523fc10
commit
46999043a5
1
backend/.gitignore → .gitignore
vendored
1
backend/.gitignore → .gitignore
vendored
|
|
@ -16,3 +16,4 @@ go.work
|
|||
*.txt
|
||||
*.pwd
|
||||
*.acl
|
||||
.idea
|
||||
|
|
@ -25,9 +25,9 @@ func main() {
|
|||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
log.Println("Starting Oktopus Project TR-369 Controller Version:", VERSION)
|
||||
flBroker := flag.Bool("m", false, "Defines if mosquitto container must run or not")
|
||||
//flBroker := flag.Bool("m", false, "Defines if mosquitto container must run or not")
|
||||
// fl_endpointId := flag.String("endpoint_id", "proto::oktopus-controller", "Defines the enpoint id the Agent must trust on.")
|
||||
flSubTopic := flag.String("s", "oktopus/+/+/agent", "That's the topic agent must publish to, and the controller keeps on listening.")
|
||||
flSubTopic := flag.String("s", "oktopus/+/agent/+", "That's the topic agent must publish to, and the controller keeps on listening.")
|
||||
// fl_pub_topic := flag.String("pub_topic", "oktopus/v1/controller", "That's the topic controller must publish to, and the agent keeps on listening.")
|
||||
flBrokerAddr := flag.String("a", "localhost", "Mqtt broker adrress")
|
||||
flBrokerPort := flag.String("p", "1883", "Mqtt broker port")
|
||||
|
|
@ -44,10 +44,10 @@ func main() {
|
|||
flag.Usage()
|
||||
os.Exit(0)
|
||||
}
|
||||
if *flBroker {
|
||||
log.Println("Starting Mqtt Broker")
|
||||
mqtt.StartMqttBroker()
|
||||
}
|
||||
//if *flBroker {
|
||||
// log.Println("Starting Mqtt Broker")
|
||||
// mqtt.StartMqttBroker()
|
||||
//}
|
||||
/*
|
||||
This context suppress our needs, but we can use a more sofisticate
|
||||
approach with cancel and timeout options passing it through paho mqtt functions.
|
||||
|
|
@ -69,8 +69,6 @@ func main() {
|
|||
CA: *flTlsCert,
|
||||
}
|
||||
|
||||
log.Println()
|
||||
|
||||
mtp.MtpService(&mqttClient, done)
|
||||
|
||||
<-done
|
||||
|
|
@ -81,6 +81,12 @@ func startClient(addr string, port string, tlsCa string, ctx context.Context, ms
|
|||
clientConfig := paho.ClientConfig{
|
||||
Conn: conn,
|
||||
Router: singleHandler,
|
||||
OnServerDisconnect: func(disconnect *paho.Disconnect) {
|
||||
log.Println("disconnected from mqtt server, reason code: ", disconnect.ReasonCode)
|
||||
},
|
||||
OnClientError: func(err error) {
|
||||
log.Println(err)
|
||||
},
|
||||
}
|
||||
return paho.NewClient(clientConfig)
|
||||
}
|
||||
|
|
@ -15,7 +15,6 @@ func StartMqttBroker() {
|
|||
|
||||
//TODO: Start Container through Docker SDK for GO, eliminating docker-compose and shell comands.
|
||||
//TODO: Create Broker with user, password and CA certificate.
|
||||
//TODO: Set broker access control list to topics.
|
||||
//TODO: Set MQTTv5 CONNACK packet with topic for agent to use.
|
||||
|
||||
cmd := exec.Command("sudo", "docker", "compose", "-f", "internal/mosquitto/docker-compose.yml", "up", "-d")
|
||||
43
backend/services/mochi/.github/workflows/build.yml
vendored
Normal file
43
backend/services/mochi/.github/workflows/build.yml
vendored
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
name: build
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.19
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -race ./... && echo true
|
||||
|
||||
|
||||
coverage:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
if: success()
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.19.x
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- name: Calc coverage
|
||||
run: |
|
||||
go test -v -covermode=count -coverprofile=coverage.out ./...
|
||||
- name: Convert coverage.out to coverage.lcov
|
||||
uses: jandelgado/gcov2lcov-action@v1.0.6
|
||||
- name: Coveralls
|
||||
uses: coverallsapp/github-action@v1.1.2
|
||||
with:
|
||||
github-token: ${{ secrets.github_token }}
|
||||
path-to-lcov: coverage.lcov
|
||||
3
backend/services/mochi/.gitignore
vendored
Normal file
3
backend/services/mochi/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
cmd/mqtt
|
||||
.DS_Store
|
||||
*.db
|
||||
103
backend/services/mochi/.golangci.yml
Normal file
103
backend/services/mochi/.golangci.yml
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
linters:
|
||||
disable-all: false
|
||||
fix: false # Fix found issues (if it's supported by the linter).
|
||||
enable:
|
||||
# - asasalint
|
||||
# - asciicheck
|
||||
# - bidichk
|
||||
# - bodyclose
|
||||
# - containedctx
|
||||
# - contextcheck
|
||||
#- cyclop
|
||||
# - deadcode
|
||||
- decorder
|
||||
# - depguard
|
||||
# - dogsled
|
||||
# - dupl
|
||||
- durationcheck
|
||||
# - errchkjson
|
||||
# - errname
|
||||
- errorlint
|
||||
# - execinquery
|
||||
# - exhaustive
|
||||
# - exhaustruct
|
||||
# - exportloopref
|
||||
#- forcetypeassert
|
||||
#- forbidigo
|
||||
#- funlen
|
||||
#- gci
|
||||
# - gochecknoglobals
|
||||
# - gochecknoinits
|
||||
# - gocognit
|
||||
# - goconst
|
||||
# - gocritic
|
||||
- gocyclo
|
||||
- godot
|
||||
# - godox
|
||||
# - goerr113
|
||||
# - gofmt
|
||||
# - gofumpt
|
||||
# - goheader
|
||||
- goimports
|
||||
# - golint
|
||||
# - gomnd
|
||||
# - gomoddirectives
|
||||
# - gomodguard
|
||||
# - goprintffuncname
|
||||
- gosec
|
||||
- gosimple
|
||||
- govet
|
||||
# - grouper
|
||||
# - ifshort
|
||||
- importas
|
||||
- ineffassign
|
||||
# - interfacebloat
|
||||
# - interfacer
|
||||
# - ireturn
|
||||
# - lll
|
||||
# - maintidx
|
||||
# - makezero
|
||||
- maligned
|
||||
- misspell
|
||||
# - nakedret
|
||||
# - nestif
|
||||
# - nilerr
|
||||
# - nilnil
|
||||
# - nlreturn
|
||||
# - noctx
|
||||
# - nolintlint
|
||||
# - nonamedreturns
|
||||
# - nosnakecase
|
||||
# - nosprintfhostport
|
||||
# - paralleltest
|
||||
# - prealloc
|
||||
# - predeclared
|
||||
# - promlinter
|
||||
- reassign
|
||||
# - revive
|
||||
# - rowserrcheck
|
||||
# - scopelint
|
||||
# - sqlclosecheck
|
||||
# - staticcheck
|
||||
# - structcheck
|
||||
# - stylecheck
|
||||
# - tagliatelle
|
||||
# - tenv
|
||||
# - testpackage
|
||||
# - thelper
|
||||
- tparallel
|
||||
# - typecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- usestdlibvars
|
||||
# - varcheck
|
||||
# - varnamelen
|
||||
- wastedassign
|
||||
- whitespace
|
||||
# - wrapcheck
|
||||
# - wsl
|
||||
disable:
|
||||
- errcheck
|
||||
|
||||
|
||||
31
backend/services/mochi/Dockerfile
Normal file
31
backend/services/mochi/Dockerfile
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
FROM golang:1.19.0-alpine3.15 AS builder
|
||||
|
||||
RUN apk update
|
||||
RUN apk add git
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod ./
|
||||
COPY go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . ./
|
||||
|
||||
RUN go build -o /app/mochi ./cmd
|
||||
|
||||
|
||||
FROM alpine
|
||||
|
||||
WORKDIR /
|
||||
COPY --from=builder /app/mochi .
|
||||
|
||||
# tcp
|
||||
EXPOSE 1883
|
||||
|
||||
# websockets
|
||||
EXPOSE 1882
|
||||
|
||||
# dashboard
|
||||
EXPOSE 8080
|
||||
|
||||
ENTRYPOINT [ "/mochi" ]
|
||||
22
backend/services/mochi/LICENSE.md
Normal file
22
backend/services/mochi/LICENSE.md
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2019, 2022 Jonathan Blake (mochi-co)
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
404
backend/services/mochi/README.md
Normal file
404
backend/services/mochi/README.md
Normal file
|
|
@ -0,0 +1,404 @@
|
|||
|
||||
<p align="center">
|
||||
|
||||

|
||||
[](https://coveralls.io/github/mochi-co/mqtt?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/mochi-co/mqtt/v2)
|
||||
[](https://pkg.go.dev/github.com/mochi-co/mqtt/v2)
|
||||
[](https://github.com/mochi-co/mqtt/issues)
|
||||
|
||||
</p>
|
||||
|
||||
# Mochi MQTT Broker
|
||||
## The fully compliant, embeddable high-performance Go MQTT v5 (and v3.1.1) broker server
|
||||
Mochi MQTT is an embeddable [fully compliant](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) MQTT v5 broker server written in Go, designed for the development of telemetry and internet-of-things projects. The server can be used either as a standalone binary or embedded as a library in your own applications, and has been designed to be as lightweight and fast as possible, with great care taken to ensure the quality and maintainability of the project.
|
||||
|
||||
### What is MQTT?
|
||||
MQTT stands for [MQ Telemetry Transport](https://en.wikipedia.org/wiki/MQTT). It is a publish/subscribe, extremely simple and lightweight messaging protocol, designed for constrained devices and low-bandwidth, high-latency or unreliable networks ([Learn more](https://mqtt.org/faq)). Mochi MQTT fully implements version 5.0.0 of the MQTT protocol.
|
||||
|
||||
## What's new in Version 2.0.0?
|
||||
Version 2.0.0 takes all the great things we loved about Mochi MQTT v1.0.0, learns from the mistakes, and improves on the things we wished we'd had. It's a total from-scratch rewrite, designed to fully implement MQTT v5 as a first-class feature.
|
||||
|
||||
Don't forget to use the new v2 import paths:
|
||||
```go
|
||||
import "github.com/mochi-co/mqtt/v2"
|
||||
```
|
||||
|
||||
- Full MQTTv5 Feature Compliance, compatibility for MQTT v3.1.1 and v3.0.0:
|
||||
- User and MQTTv5 Packet Properties
|
||||
- Topic Aliases
|
||||
- Shared Subscriptions
|
||||
- Subscription Options and Subscription Identifiers
|
||||
- Message Expiry
|
||||
- Client Session Expiry
|
||||
- Send and Receive QoS Flow Control Quotas
|
||||
- Server-side Disconnect and Auth Packets
|
||||
- Will Delay Intervals
|
||||
- Plus all the original MQTT features of Mochi MQTT v1, such as Full QoS(0,1,2), $SYS topics, retained messages, etc.
|
||||
- Developer-centric:
|
||||
- Most core broker code is now exported and accessible, for total developer control.
|
||||
- Full featured and flexible Hook-based interfacing system to provide easy 'plugin' development.
|
||||
- Direct Packet Injection using special inline client, or masquerade as existing clients.
|
||||
- Performant and Stable:
|
||||
- Our classic trie-based Topic-Subscription model.
|
||||
- Client-specific write buffers to avoid issues with slow-reading or irregular client behaviour.
|
||||
- Passes all [Paho Interoperability Tests](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) for MQTT v5 and MQTT v3.
|
||||
- Over a thousand carefully considered unit test scenarios.
|
||||
- TCP, Websocket (including SSL/TLS), and $SYS Dashboard listeners.
|
||||
- Built-in Redis, Badger, and Bolt Persistence using Hooks (but you can also make your own).
|
||||
- Built-in Rule-based Authentication and ACL Ledger using Hooks (also make your own).
|
||||
|
||||
> There is no upgrade path from v1.0.0. Please review the documentation and this readme to get a sense of the changes required (e.g. the v1 events system, auth, and persistence have all been replaced with the new hooks system).
|
||||
|
||||
### Compatibility Notes
|
||||
Because of the overlap between the v5 specification and previous versions of mqtt, the server can accept both v5 and v3 clients, but note that in cases where both v5 an v3 clients are connected, properties and features provided for v5 clients will be downgraded for v3 clients (such as user properties).
|
||||
|
||||
Support for MQTT v3.0.0 and v3.1.1 is considered hybrid-compatibility. Where not specifically restricted in the v3 specification, more modern and safety-first v5 behaviours are used instead - such as expiry for inflight and retained messages, and clients - and quality-of-service flow control limits.
|
||||
|
||||
## Roadmap
|
||||
- Please [open an issue](https://github.com/mochi-co/mqtt/issues) to request new features or event hooks!
|
||||
- Cluster support.
|
||||
- Enhanced Metrics support.
|
||||
- File-based server configuration (supporting docker).
|
||||
|
||||
## Quick Start
|
||||
### Running the Broker with Go
|
||||
Mochi MQTT can be used as a standalone broker. Simply checkout this repository and run the [cmd/main.go](cmd/main.go) entrypoint in the [cmd](cmd) folder which will expose tcp (:1883), websocket (:1882), and dashboard (:8080) listeners.
|
||||
|
||||
```
|
||||
cd cmd
|
||||
go build -o mqtt && ./mqtt
|
||||
```
|
||||
|
||||
### Using Docker
|
||||
A simple Dockerfile is provided for running the [cmd/main.go](cmd/main.go) Websocket, TCP, and Stats server:
|
||||
|
||||
```sh
|
||||
docker build -t mochi:latest .
|
||||
docker run -p 1883:1883 -p 1882:1882 -p 8080:8080 mochi:latest
|
||||
```
|
||||
|
||||
## Developing with Mochi MQTT
|
||||
### Importing as a package
|
||||
Importing Mochi MQTT as a package requires just a few lines of code to get started.
|
||||
``` go
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Create the new MQTT Server.
|
||||
server := mqtt.New(nil)
|
||||
|
||||
// Allow all connections.
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
// Create a TCP listener on a standard port.
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Examples of running the broker with various configurations can be found in the [examples](examples) folder.
|
||||
|
||||
#### Network Listeners
|
||||
The server comes with a variety of pre-packaged network listeners which allow the broker to accept connections on different protocols. The current listeners are:
|
||||
|
||||
| Listener | Usage |
|
||||
| --- | --- |
|
||||
| listeners.NewTCP | A TCP listener |
|
||||
| listeners.NewUnixSock | A Unix Socket listener |
|
||||
| listeners.NewNet | A net.Listener listener |
|
||||
| listeners.NewWebsocket | A Websocket listener |
|
||||
| listeners.NewHTTPStats | An HTTP $SYS info dashboard |
|
||||
|
||||
> Use the `listeners.Listener` interface to develop new listeners. If you do, please let us know!
|
||||
|
||||
A `*listeners.Config` may be passed to configure TLS.
|
||||
|
||||
Examples of usage can be found in the [examples](examples) folder or [cmd/main.go](cmd/main.go).
|
||||
|
||||
### Server Options and Capabilities
|
||||
A number of configurable options are available which can be used to alter the behaviour or restrict access to certain features in the server.
|
||||
|
||||
```go
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
Capabilities: mqtt.Capabilities{
|
||||
ClientNetWriteBufferSize: 4096,
|
||||
ClientNetReadBufferSize: 4096,
|
||||
MaximumSessionExpiryInterval: 3600,
|
||||
Compatibilities: mqtt.Compatibilities{
|
||||
ObscureNotAuthorized: true,
|
||||
},
|
||||
},
|
||||
SysTopicResendInterval: 10,
|
||||
})
|
||||
```
|
||||
|
||||
Review the mqtt.Options, mqtt.Capabilities, and mqtt.Compatibilities structs for a comprehensive list of options. `ClientNetWriteBufferSize` and `ClientNetReadBufferSize` can be configured to adjust memory usage per client, based on your needs.
|
||||
|
||||
|
||||
## Event Hooks
|
||||
A universal event hooks system allows developers to hook into various parts of the server and client life cycle to add and modify functionality of the broker. These universal hooks are used to provide everything from authentication, persistent storage, to debugging tools.
|
||||
|
||||
Hooks are stackable - you can add multiple hooks to a server, and they will be run in the order they were added. Some hooks modify values, and these modified values will be passed to the subsequent hooks before being returned to the runtime code.
|
||||
|
||||
| Type | Import | Info |
|
||||
| -- | -- | -- |
|
||||
| Access Control | [mochi-co/mqtt/hooks/auth . AllowHook](hooks/auth/allow_all.go) | Allow access to all connecting clients and read/write to all topics. |
|
||||
| Access Control | [mochi-co/mqtt/hooks/auth . Auth](hooks/auth/auth.go) | Rule-based access control ledger. |
|
||||
| Persistence | [mochi-co/mqtt/hooks/storage/bolt](hooks/storage/bolt/bolt.go) | Persistent storage using [BoltDB](https://dbdb.io/db/boltdb) (deprecated). |
|
||||
| Persistence | [mochi-co/mqtt/hooks/storage/badger](hooks/storage/badger/badger.go) | Persistent storage using [BadgerDB](https://github.com/dgraph-io/badger). |
|
||||
| Persistence | [mochi-co/mqtt/hooks/storage/redis](hooks/storage/redis/redis.go) | Persistent storage using [Redis](https://redis.io). |
|
||||
| Debugging | [mochi-co/mqtt/hooks/debug](hooks/debug/debug.go) | Additional debugging output to visualise packet flow. |
|
||||
|
||||
Many of the internal server functions are now exposed to developers, so you can make your own Hooks by using the above as examples. If you do, please [Open an issue](https://github.com/mochi-co/mqtt/issues) and let everyone know!
|
||||
|
||||
### Access Control
|
||||
#### Allow Hook
|
||||
By default, Mochi MQTT uses a DENY-ALL access control rule. To allow connections, this must overwritten using an Access Control hook. The simplest of these hooks is the `auth.AllowAll` hook, which provides ALLOW-ALL rules to all connections, subscriptions, and publishing. It's also the simplest hook to use:
|
||||
|
||||
```go
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
```
|
||||
|
||||
> Don't do this if you are exposing your server to the internet or untrusted networks - it should really be used for development, testing, and debugging only.
|
||||
|
||||
#### Auth Ledger
|
||||
The Auth Ledger hook provides a sophisticated mechanism for defining access rules in a struct format. Auth ledger rules come in two forms: Auth rules (connection), and ACL rules (publish subscribe).
|
||||
|
||||
Auth rules have 4 optional criteria and an assertion flag:
|
||||
| Criteria | Usage |
|
||||
| -- | -- |
|
||||
| Client | client id of the connecting client |
|
||||
| Username | username of the connecting client |
|
||||
| Password | password of the connecting client |
|
||||
| Remote | the remote address or ip of the client |
|
||||
| Allow | true (allow this user) or false (deny this user) |
|
||||
|
||||
ACL rules have 3 optional criteria and an filter match:
|
||||
| Criteria | Usage |
|
||||
| -- | -- |
|
||||
| Client | client id of the connecting client |
|
||||
| Username | username of the connecting client |
|
||||
| Remote | the remote address or ip of the client |
|
||||
| Filters | an array of filters to match |
|
||||
|
||||
Rules are processed in index order (0,1,2,3), returning on the first matching rule. See [hooks/auth/ledger.go](hooks/auth/ledger.go) to review the structs.
|
||||
|
||||
```go
|
||||
server := mqtt.New(nil)
|
||||
err := server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Ledger: &auth.Ledger{
|
||||
Auth: auth.AuthRules{ // Auth disallows all by default
|
||||
{Username: "peach", Password: "password1", Allow: true},
|
||||
{Username: "melon", Password: "password2", Allow: true},
|
||||
{Remote: "127.0.0.1:*", Allow: true},
|
||||
{Remote: "localhost:*", Allow: true},
|
||||
},
|
||||
ACL: auth.ACLRules{ // ACL allows all by default
|
||||
{Remote: "127.0.0.1:*"}, // local superuser allow all
|
||||
{
|
||||
// user melon can read and write to their own topic
|
||||
Username: "melon", Filters: auth.Filters{
|
||||
"melon/#": auth.ReadWrite,
|
||||
"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
|
||||
},
|
||||
},
|
||||
{
|
||||
// Otherwise, no clients have publishing permissions
|
||||
Filters: auth.Filters{
|
||||
"#": auth.ReadOnly,
|
||||
"updates/#": auth.Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
The ledger can also be stored as JSON or YAML and loaded using the Data field:
|
||||
```go
|
||||
err = server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Data: data, // build ledger from byte slice: yaml or json
|
||||
})
|
||||
```
|
||||
See [examples/auth/encoded/main.go](examples/auth/encoded/main.go) for more information.
|
||||
|
||||
### Persistent Storage
|
||||
#### Redis
|
||||
A basic Redis storage hook is available which provides persistence for the broker. It can be added to the server in the same fashion as any other hook, with several options. It uses github.com/go-redis/redis/v8 under the hook, and is completely configurable through the Options value.
|
||||
```go
|
||||
err := server.AddHook(new(redis.Hook), &redis.Options{
|
||||
Options: &rv8.Options{
|
||||
Addr: "localhost:6379", // default redis address
|
||||
Password: "", // your password
|
||||
DB: 0, // your redis db
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
For more information on how the redis hook works, or how to use it, see the [examples/persistence/redis/main.go](examples/persistence/redis/main.go) or [hooks/storage/redis](hooks/storage/redis) code.
|
||||
|
||||
#### Badger DB
|
||||
There's also a BadgerDB storage hook if you prefer file based storage. It can be added and configured in much the same way as the other hooks (with somewhat less options).
|
||||
```go
|
||||
err := server.AddHook(new(badger.Hook), &badger.Options{
|
||||
Path: badgerPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
For more information on how the badger hook works, or how to use it, see the [examples/persistence/badger/main.go](examples/persistence/badger/main.go) or [hooks/storage/badger](hooks/storage/badger) code.
|
||||
|
||||
There is also a BoltDB hook which has been deprecated in favour of Badger, but if you need it, check [examples/persistence/bolt/main.go](examples/persistence/bolt/main.go).
|
||||
|
||||
|
||||
|
||||
## Developing with Event Hooks
|
||||
Many hooks are available for interacting with the broker and client lifecycle.
|
||||
The function signatures for all the hooks and `mqtt.Hook` interface can be found in [hooks.go](hooks.go).
|
||||
|
||||
> The most flexible event hooks are OnPacketRead, OnPacketEncode, and OnPacketSent - these hooks be used to control and modify all incoming and outgoing packets.
|
||||
|
||||
| Function | Usage |
|
||||
| -------------------------- | -- |
|
||||
| OnStarted | Called when the server has successfully started.|
|
||||
| OnStopped | Called when the server has successfully stopped. |
|
||||
| OnConnectAuthenticate | Called when a user attempts to authenticate with the server. An implementation of this method MUST be used to allow or deny access to the server (see hooks/auth/allow_all or basic). It can be used in custom hooks to check connecting users against an existing user database. Returns true if allowed. |
|
||||
| OnACLCheck | Called when a user attempts to publish or subscribe to a topic filter. As above. |
|
||||
| OnSysInfoTick | Called when the $SYS topic values are published out. |
|
||||
| OnConnect | Called when a new client connects |
|
||||
| OnSessionEstablished | Called when a new client successfully establishes a session (after OnConnect) |
|
||||
| OnDisconnect | Called when a client is disconnected for any reason. |
|
||||
| OnAuthPacket | Called when an auth packet is received. It is intended to allow developers to create their own mqtt v5 Auth Packet handling mechanisms. Allows packet modification. |
|
||||
| OnPacketRead | Called when a packet is received from a client. Allows packet modification. |
|
||||
| OnPacketEncode | Called immediately before a packet is encoded to be sent to a client. Allows packet modification. |
|
||||
| OnPacketSent | Called when a packet has been sent to a client. |
|
||||
| OnPacketProcessed | Called when a packet has been received and successfully handled by the broker. |
|
||||
| OnSubscribe | Called when a client subscribes to one or more filters. Allows packet modification. |
|
||||
| OnSubscribed | Called when a client successfully subscribes to one or more filters. |
|
||||
| OnSelectSubscribers | Called when subscribers have been collected for a topic, but before shared subscription subscribers have been selected. Allows receipient modification.|
|
||||
| OnUnsubscribe | Called when a client unsubscribes from one or more filters. Allows packet modification. |
|
||||
| OnUnsubscribed | Called when a client successfully unsubscribes from one or more filters. |
|
||||
| OnPublish | Called when a client publishes a message. Allows packet modification. |
|
||||
| OnPublished | Called when a client has published a message to subscribers. |
|
||||
| OnPublishDropped | Called when a message to a client is dropped before delivery, such as if the client is taking too long to respond. |
|
||||
| OnRetainMessage | Called then a published message is retained. |
|
||||
| OnQosPublish | Called when a publish packet with Qos >= 1 is issued to a subscriber. |
|
||||
| OnQosComplete | Called when the Qos flow for a message has been completed. |
|
||||
| OnQosDropped | Called when an inflight message expires before completion. |
|
||||
| OnWill | Called when a client disconnects and intends to issue a will message. Allows packet modification. |
|
||||
| OnWillSent | Called when an LWT message has been issued from a disconnecting client. |
|
||||
| OnClientExpired | Called when a client session has expired and should be deleted. |
|
||||
| OnRetainedExpired | Called when a retained message has expired and should be deleted. |
|
||||
| StoredClients | Returns clients, eg. from a persistent store. |
|
||||
| StoredSubscriptions | Returns client subscriptions, eg. from a persistent store. |
|
||||
| StoredInflightMessages | Returns inflight messages, eg. from a persistent store. |
|
||||
| StoredRetainedMessages | Returns retained messages, eg. from a persistent store. |
|
||||
| StoredSysInfo | Returns stored system info values, eg. from a persistent store. |
|
||||
|
||||
If you are building a persistent storage hook, see the existing persistent hooks for inspiration and patterns. If you are building an auth hook, you will need `OnACLCheck` and `OnConnectAuthenticate`.
|
||||
|
||||
|
||||
### Direct Publish
|
||||
To publish basic message to a topic from within the embedding application, you can use the `server.Publish(topic string, payload []byte, retain bool, qos byte) error` method.
|
||||
|
||||
```go
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
```
|
||||
> The Qos byte in this case is only used to set the upper qos limit available for subscribers, as per MQTT v5 spec.
|
||||
|
||||
### Packet Injection
|
||||
If you want more control, or want to set specific MQTT v5 properties and other values you can create your own publish packets from a client of your choice. This method allows you to inject MQTT packets (no just publish) directly into the runtime as though they had been received by a specific client. Most of the time you'll want to use the special client flag `inline=true`, as it has unique privileges: it bypasses all ACL and topic validation checks, meaning it can even publish to $SYS topics.
|
||||
|
||||
Packet injection can be used for any MQTT packet, including ping requests, subscriptions, etc. And because the Clients structs and methods are now exported, you can even inject packets on behalf of a connected client (if you have a very custom requirements).
|
||||
|
||||
```go
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "direct/publish",
|
||||
Payload: []byte("scheduled message"),
|
||||
})
|
||||
```
|
||||
|
||||
> MQTT packets still need to be correctly formed, so refer our [the test packets catalogue](packets/tpackets.go) and [MQTTv5 Specification](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) for inspiration.
|
||||
|
||||
See the [hooks example](examples/hooks/main.go) to see this feature in action.
|
||||
|
||||
|
||||
|
||||
### Testing
|
||||
#### Unit Tests
|
||||
Mochi MQTT tests over a thousand scenarios with thoughtfully hand written unit tests to ensure each function does exactly what we expect. You can run the tests using go:
|
||||
```
|
||||
go run --cover ./...
|
||||
```
|
||||
|
||||
#### Paho Interoperability Test
|
||||
You can check the broker against the [Paho Interoperability Test](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) by starting the broker using `examples/paho/main.go`, and then running the mqtt v5 and v3 tests with `python3 client_test5.py` from the _interoperability_ folder.
|
||||
|
||||
> Note that there are currently a number of outstanding issues regarding false negatives in the paho suite, and as such, certain compatibility modes are enabled in the `paho/main.go` example.
|
||||
|
||||
|
||||
## Performance Benchmarks
|
||||
Mochi MQTT performance is comparable with popular brokers such as Mosquitto, EMQX, and others.
|
||||
|
||||
Performance benchmarks were tested using [MQTT-Stresser](https://github.com/inovex/mqtt-stresser) on a Apple Macbook Air M2, using `cmd/main.go` default settings. Taking into account bursts of high and low throughput, the median scores are the most useful. Higher is better.
|
||||
|
||||
> The values presented in the benchmark are not representative of true messages per second throughput. They rely on an unusual calculation by mqtt-stresser, but are usable as they are consistent across all brokers.
|
||||
> Benchmarks are provided as a general performance expectation guideline only.
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=2 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.2.0 | 127,216 | 125,748 | 124,279 | 319,250 | 309,327 | 299,405 |
|
||||
| Mosquitto v2.0.15 | 155,920 | 155,919 | 155,918 | 185,485 | 185,097 | 184,709 |
|
||||
| EMQX v5.0.11 | 156,945 | 156,257 | 155,568 | 17,918 | 17,783 | 17,649 |
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=10 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.2.0 | 45,615 | 30,129 | 21,138 | 232,717 | 86,323 | 50,402 |
|
||||
| Mosquitto v2.0.15 | 42,729 | 38,633 | 29,879 | 23,241 | 19,714 | 18,806 |
|
||||
| EMQX v5.0.11 | 21,553 | 17,418 | 14,356 | 4,257 | 3,980 | 3,756 |
|
||||
|
||||
Million Message Challenge (hit the server with 1 million messages immediately):
|
||||
|
||||
`mqtt-stresser -broker tcp://localhost:1883 -num-clients=100 -num-messages=10000`
|
||||
| Broker | publish fastest | median | slowest | receive fastest | median | slowest |
|
||||
| -- | -- | -- | -- | -- | -- | -- |
|
||||
| Mochi v2.2.0 | 51,044 | 4,682 | 2,345 | 72,634 | 7,645 | 2,464 |
|
||||
| Mosquitto v2.0.15 | 3,826 | 3,395 | 3,032 | 1,200 | 1,150 | 1,118 |
|
||||
| EMQX v5.0.11 | 4,086 | 2,432 | 2,274 | 434 | 333 | 311 |
|
||||
|
||||
> Not sure what's going on with EMQX here, perhaps the docker out-of-the-box settings are not optimal, so take it with a pinch of salt as we know for a fact it's a solid piece of software.
|
||||
|
||||
## Stargazers over time 🥰
|
||||
[](https://starchart.cc/mochi-co/mqtt)
|
||||
Are you using Mochi MQTT in a project? [Let us know!](https://github.com/mochi-co/mqtt/issues)
|
||||
|
||||
## Contributions
|
||||
Contributions and feedback are both welcomed and encouraged! [Open an issue](https://github.com/mochi-co/mqtt/issues) to report a bug, ask a question, or make a feature request.
|
||||
|
||||
|
||||
|
||||
568
backend/services/mochi/clients.go
Normal file
568
backend/services/mochi/clients.go
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds
|
||||
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
|
||||
)
|
||||
|
||||
// ReadFn is the function signature for the function used for reading and processing new packets.
|
||||
type ReadFn func(*Client, packets.Packet) error
|
||||
|
||||
// Clients contains a map of the clients known by the broker.
|
||||
type Clients struct {
|
||||
internal map[string]*Client // clients known by the broker, keyed on client id.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClients returns an instance of Clients.
|
||||
func NewClients() *Clients {
|
||||
return &Clients{
|
||||
internal: make(map[string]*Client),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new client to the clients map, keyed on client id.
|
||||
func (cl *Clients) Add(val *Client) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
cl.internal[val.ID] = val
|
||||
}
|
||||
|
||||
// GetAll returns all the clients.
|
||||
func (cl *Clients) GetAll() map[string]*Client {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
m := map[string]*Client{}
|
||||
for k, v := range cl.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns the value of a client if it exists.
|
||||
func (cl *Clients) Get(id string) (*Client, bool) {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
val, ok := cl.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the length of the clients map.
|
||||
func (cl *Clients) Len() int {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
val := len(cl.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes a client from the internal map.
|
||||
func (cl *Clients) Delete(id string) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
delete(cl.internal, id)
|
||||
}
|
||||
|
||||
// GetByListener returns clients matching a listener id.
|
||||
func (cl *Clients) GetByListener(id string) []*Client {
|
||||
cl.RLock()
|
||||
defer cl.RUnlock()
|
||||
clients := make([]*Client, 0, cl.Len())
|
||||
for _, client := range cl.internal {
|
||||
if client.Net.Listener == id && atomic.LoadUint32(&client.State.done) == 0 {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
}
|
||||
return clients
|
||||
}
|
||||
|
||||
// Client contains information about a client known by the broker.
|
||||
type Client struct {
|
||||
Properties ClientProperties // client properties
|
||||
State ClientState // the operational state of the client.
|
||||
Net ClientConnection // network connection state of the clinet
|
||||
ID string // the client id.
|
||||
ops *ops // ops provides a reference to server ops.
|
||||
sync.RWMutex // mutex
|
||||
}
|
||||
|
||||
// ClientConnection contains the connection transport and metadata for the client.
|
||||
type ClientConnection struct {
|
||||
Conn net.Conn // the net.Conn used to establish the connection
|
||||
bconn *bufio.ReadWriter // a buffered net.Conn for reading packets
|
||||
Remote string // the remote address of the client
|
||||
Listener string // listener id of the client
|
||||
Inline bool // client is an inline programmetic client
|
||||
}
|
||||
|
||||
// ClientProperties contains the properties which define the client behaviour.
|
||||
type ClientProperties struct {
|
||||
Props packets.Properties
|
||||
Will Will
|
||||
Username []byte
|
||||
ProtocolVersion byte
|
||||
Clean bool
|
||||
}
|
||||
|
||||
// Will contains the last will and testament details for a client connection.
|
||||
type Will struct {
|
||||
Payload []byte // -
|
||||
User []packets.UserProperty // -
|
||||
TopicName string // -
|
||||
Flag uint32 // 0,1
|
||||
WillDelayInterval uint32 // -
|
||||
Qos byte // -
|
||||
Retain bool // -
|
||||
}
|
||||
|
||||
// State tracks the state of the client.
|
||||
type ClientState struct {
|
||||
TopicAliases TopicAliases // a map of topic aliases
|
||||
stopCause atomic.Value // reason for stopping
|
||||
Inflight *Inflight // a map of in-flight qos messages
|
||||
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
|
||||
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
|
||||
outbound chan *packets.Packet // queue for pending outbound packets
|
||||
endOnce sync.Once // only end once
|
||||
isTakenOver uint32 // used to identify orphaned clients
|
||||
packetID uint32 // the current highest packetID
|
||||
done uint32 // atomic counter which indicates that the client has closed
|
||||
outboundQty int32 // number of messages currently in the outbound queue
|
||||
keepalive uint16 // the number of seconds the connection can wait
|
||||
}
|
||||
|
||||
// newClient returns a new instance of Client. This is almost exclusively used by Server
|
||||
// for creating new clients, but it lives here because it's not dependent.
|
||||
func newClient(c net.Conn, o *ops) *Client {
|
||||
cl := &Client{
|
||||
State: ClientState{
|
||||
Inflight: NewInflights(),
|
||||
Subscriptions: NewSubscriptions(),
|
||||
TopicAliases: NewTopicAliases(o.options.Capabilities.TopicAliasMaximum),
|
||||
keepalive: defaultKeepalive,
|
||||
outbound: make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending),
|
||||
},
|
||||
Properties: ClientProperties{
|
||||
ProtocolVersion: defaultClientProtocolVersion, // default protocol version
|
||||
},
|
||||
ops: o,
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
cl.Net = ClientConnection{
|
||||
Conn: c,
|
||||
bconn: bufio.NewReadWriter(
|
||||
bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize),
|
||||
bufio.NewWriterSize(c, o.options.ClientNetReadBufferSize),
|
||||
),
|
||||
Remote: c.RemoteAddr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// WriteLoop ranges over pending outbound messages and writes them to the client connection.
|
||||
func (cl *Client) WriteLoop() {
|
||||
for pk := range cl.State.outbound {
|
||||
if err := cl.WritePacket(*pk); err != nil {
|
||||
cl.ops.log.Debug().Err(err).Str("client", cl.ID).Interface("packet", pk).Msg("failed publishing packet")
|
||||
}
|
||||
atomic.AddInt32(&cl.State.outboundQty, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseConnect parses the connect parameters and properties for a client.
|
||||
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
|
||||
cl.Net.Listener = lid
|
||||
|
||||
cl.Properties.ProtocolVersion = pk.ProtocolVersion
|
||||
cl.Properties.Username = pk.Connect.Username
|
||||
cl.Properties.Clean = pk.Connect.Clean
|
||||
cl.Properties.Props = pk.Properties.Copy(false)
|
||||
|
||||
cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
|
||||
cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum)) // client receive max
|
||||
|
||||
cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)
|
||||
|
||||
cl.ID = pk.Connect.ClientIdentifier
|
||||
if cl.ID == "" {
|
||||
cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
|
||||
cl.Properties.Props.AssignedClientID = cl.ID
|
||||
}
|
||||
|
||||
cl.State.keepalive = cl.ops.options.Capabilities.ServerKeepAlive
|
||||
if pk.Connect.Keepalive > 0 {
|
||||
cl.State.keepalive = pk.Connect.Keepalive // [MQTT-3.2.2-22]
|
||||
}
|
||||
|
||||
if pk.Connect.WillFlag {
|
||||
cl.Properties.Will = Will{
|
||||
Qos: pk.Connect.WillQos,
|
||||
Retain: pk.Connect.WillRetain,
|
||||
Payload: pk.Connect.WillPayload,
|
||||
TopicName: pk.Connect.WillTopic,
|
||||
WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
|
||||
User: pk.Connect.WillProperties.User,
|
||||
}
|
||||
if pk.Properties.SessionExpiryIntervalFlag &&
|
||||
pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
|
||||
cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
|
||||
}
|
||||
if pk.Connect.WillFlag {
|
||||
cl.Properties.Will.Flag = 1 // atomic for checking
|
||||
}
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
}
|
||||
|
||||
// refreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||
func (cl *Client) refreshDeadline(keepalive uint16) {
|
||||
var expiry time.Time // nil time can be used to disable deadline if keepalive = 0
|
||||
if keepalive > 0 {
|
||||
expiry = time.Now().Add(time.Duration(keepalive+(keepalive/2)) * time.Second) // [MQTT-3.1.2-22]
|
||||
}
|
||||
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22]
|
||||
}
|
||||
}
|
||||
|
||||
// NextPacketID returns the next available (unused) packet id for the client.
|
||||
// If no unused packet ids are available, an error is returned and the client
|
||||
// should be disconnected.
|
||||
func (cl *Client) NextPacketID() (i uint32, err error) {
|
||||
cl.Lock()
|
||||
defer cl.Unlock()
|
||||
|
||||
i = atomic.LoadUint32(&cl.State.packetID)
|
||||
started := i
|
||||
overflowed := false
|
||||
for {
|
||||
if overflowed && i == started {
|
||||
return 0, packets.ErrQuotaExceeded
|
||||
}
|
||||
|
||||
if i >= cl.ops.options.Capabilities.maximumPacketID {
|
||||
overflowed = true
|
||||
i = 0
|
||||
continue
|
||||
}
|
||||
|
||||
i++
|
||||
|
||||
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
|
||||
atomic.StoreUint32(&cl.State.packetID, i)
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ResendInflightMessages attempts to resend any pending inflight messages to connected clients.
|
||||
func (cl *Client) ResendInflightMessages(force bool) error {
|
||||
if cl.State.Inflight.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if tk.FixedHeader.Type == packets.Publish {
|
||||
tk.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
|
||||
}
|
||||
|
||||
cl.ops.hooks.OnQosPublish(cl, tk, tk.Created, 0)
|
||||
err := cl.WritePacket(tk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tk.FixedHeader.Type == packets.Puback || tk.FixedHeader.Type == packets.Pubcomp {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosComplete(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearInflights deletes all inflight messages for the client, eg. for a disconnected user with a clean session.
|
||||
func (cl *Client) ClearInflights(now, maximumExpiry int64) []uint16 {
|
||||
deleted := []uint16{}
|
||||
for _, tk := range cl.State.Inflight.GetAll(false) {
|
||||
if (tk.Expiry > 0 && tk.Expiry < now) || tk.Created+maximumExpiry < now {
|
||||
if ok := cl.State.Inflight.Delete(tk.PacketID); ok {
|
||||
cl.ops.hooks.OnQosDropped(cl, tk)
|
||||
atomic.AddInt64(&cl.ops.info.Inflight, -1)
|
||||
deleted = append(deleted, tk.PacketID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return deleted
|
||||
}
|
||||
|
||||
// Read reads incoming packets from the connected client and transforms them into
|
||||
// packets to be handled by the packetHandler.
|
||||
func (cl *Client) Read(packetHandler ReadFn) error {
|
||||
var err error
|
||||
|
||||
for {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cl.refreshDeadline(cl.State.keepalive)
|
||||
fh := new(packets.FixedHeader)
|
||||
err = cl.ReadFixedHeader(fh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = packetHandler(cl, pk) // Process inbound packet.
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop instructs the client to shut down all processing goroutines and disconnect.
|
||||
func (cl *Client) Stop(err error) {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
cl.State.endOnce.Do(func() {
|
||||
if cl.Net.Conn != nil {
|
||||
_ = cl.Net.Conn.Close() // omit close error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cl.State.stopCause.Store(err)
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&cl.State.done, 1)
|
||||
atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
|
||||
})
|
||||
}
|
||||
|
||||
// StopCause returns the reason the client connection was stopped, if any.
|
||||
func (cl *Client) StopCause() error {
|
||||
if cl.State.stopCause.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
return cl.State.stopCause.Load().(error)
|
||||
}
|
||||
|
||||
// Closed returns true if client connection is closed.
|
||||
func (cl *Client) Closed() bool {
|
||||
return atomic.LoadUint32(&cl.State.done) == 1
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
if cl.Net.bconn == nil {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
b, err := cl.Net.bconn.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fh.Decode(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var bu int
|
||||
fh.Remaining, bu, err = packets.DecodeLength(cl.Net.bconn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cl.ops.options.Capabilities.MaximumPacketSize > 0 && uint32(fh.Remaining+1) > cl.ops.options.Capabilities.MaximumPacketSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(bu+1))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadPacket reads the remaining buffer into an MQTT packet.
|
||||
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
|
||||
atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)
|
||||
|
||||
pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
|
||||
pk.FixedHeader = *fh
|
||||
p := make([]byte, pk.FixedHeader.Remaining)
|
||||
n, err := io.ReadFull(cl.Net.bconn, p)
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))
|
||||
|
||||
// Decode the remaining packet values using a fresh copy of the bytes,
|
||||
// otherwise the next packet will change the data of this one.
|
||||
px := append([]byte{}, p[:]...)
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
err = pk.ConnectDecode(px)
|
||||
case packets.Disconnect:
|
||||
err = pk.DisconnectDecode(px)
|
||||
case packets.Connack:
|
||||
err = pk.ConnackDecode(px)
|
||||
case packets.Publish:
|
||||
err = pk.PublishDecode(px)
|
||||
if err == nil {
|
||||
atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
|
||||
}
|
||||
case packets.Puback:
|
||||
err = pk.PubackDecode(px)
|
||||
case packets.Pubrec:
|
||||
err = pk.PubrecDecode(px)
|
||||
case packets.Pubrel:
|
||||
err = pk.PubrelDecode(px)
|
||||
case packets.Pubcomp:
|
||||
err = pk.PubcompDecode(px)
|
||||
case packets.Subscribe:
|
||||
err = pk.SubscribeDecode(px)
|
||||
case packets.Suback:
|
||||
err = pk.SubackDecode(px)
|
||||
case packets.Unsubscribe:
|
||||
err = pk.UnsubscribeDecode(px)
|
||||
case packets.Unsuback:
|
||||
err = pk.UnsubackDecode(px)
|
||||
case packets.Pingreq:
|
||||
case packets.Pingresp:
|
||||
case packets.Auth:
|
||||
err = pk.AuthDecode(px)
|
||||
default:
|
||||
err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
|
||||
return
|
||||
}
|
||||
|
||||
// WritePacket encodes and writes a packet to the client.
|
||||
func (cl *Client) WritePacket(pk packets.Packet) error {
|
||||
if atomic.LoadUint32(&cl.State.done) == 1 {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
if cl.Net.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pk.Expiry > 0 {
|
||||
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
|
||||
}
|
||||
|
||||
pk.ProtocolVersion = cl.Properties.ProtocolVersion
|
||||
if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
|
||||
pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
|
||||
}
|
||||
|
||||
if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
|
||||
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
|
||||
}
|
||||
|
||||
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
|
||||
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
|
||||
}
|
||||
|
||||
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
|
||||
|
||||
var err error
|
||||
buf := new(bytes.Buffer)
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
err = pk.ConnectEncode(buf)
|
||||
case packets.Connack:
|
||||
err = pk.ConnackEncode(buf)
|
||||
case packets.Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
case packets.Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
case packets.Pubrec:
|
||||
err = pk.PubrecEncode(buf)
|
||||
case packets.Pubrel:
|
||||
err = pk.PubrelEncode(buf)
|
||||
case packets.Pubcomp:
|
||||
err = pk.PubcompEncode(buf)
|
||||
case packets.Subscribe:
|
||||
err = pk.SubscribeEncode(buf)
|
||||
case packets.Suback:
|
||||
err = pk.SubackEncode(buf)
|
||||
case packets.Unsubscribe:
|
||||
err = pk.UnsubscribeEncode(buf)
|
||||
case packets.Unsuback:
|
||||
err = pk.UnsubackEncode(buf)
|
||||
case packets.Pingreq:
|
||||
err = pk.PingreqEncode(buf)
|
||||
case packets.Pingresp:
|
||||
err = pk.PingrespEncode(buf)
|
||||
case packets.Disconnect:
|
||||
err = pk.DisconnectEncode(buf)
|
||||
case packets.Auth:
|
||||
err = pk.AuthEncode(buf)
|
||||
default:
|
||||
err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
|
||||
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
|
||||
}
|
||||
|
||||
nb := net.Buffers{buf.Bytes()}
|
||||
n, err := nb.WriteTo(cl.Net.Conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.AddInt64(&cl.ops.info.BytesSent, n)
|
||||
atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
|
||||
if pk.FixedHeader.Type == packets.Publish {
|
||||
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
|
||||
}
|
||||
|
||||
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
|
||||
|
||||
return err
|
||||
}
|
||||
745
backend/services/mochi/clients_test.go
Normal file
745
backend/services/mochi/clients_test.go
Normal file
|
|
@ -0,0 +1,745 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const pkInfo = "packet type %v, %s"
|
||||
|
||||
var errClientStop = errors.New("test stop")
|
||||
|
||||
func newTestClient() (cl *Client, r net.Conn, w net.Conn) {
|
||||
r, w = net.Pipe()
|
||||
|
||||
cl = newClient(w, &ops{
|
||||
info: new(system.Info),
|
||||
hooks: new(Hooks),
|
||||
log: &logger,
|
||||
options: &Options{
|
||||
Capabilities: &Capabilities{
|
||||
ReceiveMaximum: 10,
|
||||
TopicAliasMaximum: 10000,
|
||||
MaximumClientWritesPending: 3,
|
||||
maximumPacketID: 10,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
cl.ID = "mochi"
|
||||
cl.State.Inflight.maximumSendQuota = 5
|
||||
cl.State.Inflight.sendQuota = 5
|
||||
cl.State.Inflight.maximumReceiveQuota = 10
|
||||
cl.State.Inflight.receiveQuota = 10
|
||||
cl.Properties.Props.TopicAliasMaximum = 0
|
||||
cl.Properties.Props.RequestResponseInfo = 0x1
|
||||
|
||||
go cl.WriteLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestNewInflights(t *testing.T) {
|
||||
require.NotNil(t, NewInflights().internal)
|
||||
}
|
||||
|
||||
func TestNewClients(t *testing.T) {
|
||||
cl := NewClients()
|
||||
require.NotNil(t, cl.internal)
|
||||
}
|
||||
|
||||
func TestClientsAdd(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
}
|
||||
|
||||
func TestClientsGet(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
client, ok := cl.Get("t1")
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, "t1", client.ID)
|
||||
}
|
||||
|
||||
func TestClientsGetAll(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
cl.Add(&Client{ID: "t3"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
require.Contains(t, cl.internal, "t3")
|
||||
|
||||
clients := cl.GetAll()
|
||||
require.Len(t, clients, 3)
|
||||
}
|
||||
|
||||
func TestClientsLen(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
cl.Add(&Client{ID: "t2"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
require.Equal(t, 2, cl.Len())
|
||||
}
|
||||
|
||||
func TestClientsDelete(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1"})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
|
||||
cl.Delete("t1")
|
||||
_, ok := cl.Get("t1")
|
||||
require.Equal(t, false, ok)
|
||||
require.Nil(t, cl.internal["t1"])
|
||||
}
|
||||
|
||||
func TestClientsGetByListener(t *testing.T) {
|
||||
cl := NewClients()
|
||||
cl.Add(&Client{ID: "t1", Net: ClientConnection{Listener: "tcp1"}})
|
||||
cl.Add(&Client{ID: "t2", Net: ClientConnection{Listener: "ws1"}})
|
||||
require.Contains(t, cl.internal, "t1")
|
||||
require.Contains(t, cl.internal, "t2")
|
||||
|
||||
clients := cl.GetByListener("tcp1")
|
||||
require.NotEmpty(t, clients)
|
||||
require.Equal(t, 1, len(clients))
|
||||
require.Equal(t, "tcp1", clients[0].Net.Listener)
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
require.NotNil(t, cl)
|
||||
require.NotNil(t, cl.State.Inflight.internal)
|
||||
require.NotNil(t, cl.State.Subscriptions)
|
||||
require.NotNil(t, cl.State.TopicAliases)
|
||||
require.Equal(t, defaultKeepalive, cl.State.keepalive)
|
||||
require.Equal(t, defaultClientProtocolVersion, cl.Properties.ProtocolVersion)
|
||||
require.NotNil(t, cl.Net.Conn)
|
||||
require.NotNil(t, cl.Net.bconn)
|
||||
require.NotNil(t, cl.ops)
|
||||
require.NotNil(t, cl.ops.options.Capabilities)
|
||||
require.False(t, cl.Net.Inline)
|
||||
}
|
||||
|
||||
func TestClientParseConnect(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Clean: true,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "mochi",
|
||||
WillFlag: true,
|
||||
WillTopic: "lwt",
|
||||
WillPayload: []byte("lol gg"),
|
||||
WillQos: 1,
|
||||
WillRetain: false,
|
||||
},
|
||||
Properties: packets.Properties{
|
||||
ReceiveMaximum: uint16(5),
|
||||
},
|
||||
}
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.Keepalive, cl.State.keepalive)
|
||||
require.Equal(t, pk.Connect.Clean, cl.Properties.Clean)
|
||||
require.Equal(t, pk.Connect.ClientIdentifier, cl.ID)
|
||||
require.Equal(t, pk.Connect.WillTopic, cl.Properties.Will.TopicName)
|
||||
require.Equal(t, pk.Connect.WillPayload, cl.Properties.Will.Payload)
|
||||
require.Equal(t, pk.Connect.WillQos, cl.Properties.Will.Qos)
|
||||
require.Equal(t, pk.Connect.WillRetain, cl.Properties.Will.Retain)
|
||||
require.Equal(t, uint32(1), cl.Properties.Will.Flag)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.receiveQuota)
|
||||
require.Equal(t, int32(cl.ops.options.Capabilities.ReceiveMaximum), cl.State.Inflight.maximumReceiveQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.sendQuota)
|
||||
require.Equal(t, int32(pk.Properties.ReceiveMaximum), cl.State.Inflight.maximumSendQuota)
|
||||
}
|
||||
|
||||
func TestClientParseConnectOverrideWillDelay(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
pk := packets.Packet{
|
||||
ProtocolVersion: 4,
|
||||
Connect: packets.ConnectParams{
|
||||
ProtocolName: []byte{'M', 'Q', 'T', 'T'},
|
||||
Clean: true,
|
||||
Keepalive: 60,
|
||||
ClientIdentifier: "mochi",
|
||||
WillFlag: true,
|
||||
WillProperties: packets.Properties{
|
||||
WillDelayInterval: 200,
|
||||
},
|
||||
},
|
||||
Properties: packets.Properties{
|
||||
SessionExpiryInterval: 100,
|
||||
SessionExpiryIntervalFlag: true,
|
||||
},
|
||||
}
|
||||
|
||||
cl.ParseConnect("tcp1", pk)
|
||||
require.Equal(t, pk.Properties.SessionExpiryInterval, cl.Properties.Will.WillDelayInterval)
|
||||
}
|
||||
|
||||
func TestClientParseConnectNoID(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.ParseConnect("tcp1", packets.Packet{})
|
||||
require.NotEmpty(t, cl.ID)
|
||||
}
|
||||
|
||||
func TestClientNextPacketID(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(2), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDInUse(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
// skip over 2
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(3), i)
|
||||
|
||||
// Skip over overflow
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 65535})
|
||||
atomic.StoreUint32(&cl.State.packetID, 65534)
|
||||
|
||||
i, err = cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDExhausted(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(1); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{PacketID: uint16(i)}
|
||||
}
|
||||
|
||||
i, err := cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
require.Equal(t, uint32(0), i)
|
||||
}
|
||||
|
||||
func TestClientNextPacketIDOverflow(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
for i := uint32(0); i < cl.ops.options.Capabilities.maximumPacketID; i++ {
|
||||
cl.State.Inflight.internal[uint16(i)] = packets.Packet{}
|
||||
}
|
||||
|
||||
cl.State.packetID = uint32(cl.ops.options.Capabilities.maximumPacketID - 1)
|
||||
i, err := cl.NextPacketID()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ops.options.Capabilities.maximumPacketID, i)
|
||||
cl.State.Inflight.internal[uint16(cl.ops.options.Capabilities.maximumPacketID)] = packets.Packet{}
|
||||
|
||||
cl.State.packetID = cl.ops.options.Capabilities.maximumPacketID
|
||||
_, err = cl.NextPacketID()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrQuotaExceeded)
|
||||
}
|
||||
|
||||
func TestClientClearInflights(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
n := time.Now().Unix()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Expiry: n - 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Expiry: n - 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: n - 3}) // within bounds
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: n - 5}) // over max server expiry limit
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 7, Created: n})
|
||||
require.Equal(t, 5, cl.State.Inflight.Len())
|
||||
|
||||
deleted := cl.ClearInflights(n, 4)
|
||||
require.Len(t, deleted, 3)
|
||||
require.ElementsMatch(t, []uint16{1, 2, 5}, deleted)
|
||||
require.Equal(t, 2, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessages(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Puback].Get(packets.TPuback)
|
||||
cl, r, w := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
go func() {
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, cl.State.Inflight.Len())
|
||||
require.Equal(t, pk1.RawBytes, buf)
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesWriteFailure(t *testing.T) {
|
||||
pk1 := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup)
|
||||
cl, r, _ := newTestClient()
|
||||
r.Close()
|
||||
|
||||
cl.State.Inflight.Set(*pk1.Packet)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe)
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestClientResendInflightMessagesNoMessages(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.ResendInflightMessages(true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientRefreshDeadline(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.refreshDeadline(10)
|
||||
require.NotNil(t, cl.Net.Conn) // how do we check net.Conn deadline?
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeader(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect << 4, 0x00})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.BytesReceived))
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderDecodeError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect<<4 | 1<<1, 0x00, 0x00})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderPacketOversized(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
cl.ops.options.Capabilities.MaximumPacketSize = 2
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write(packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes)
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrPacketTooLarge)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderReadEOF(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderNoLengthTerminator(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
go func() {
|
||||
r.Write([]byte{packets.Connect << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadOK(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
packets.Publish << 4, 18, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'a', '/', 'b', '/', 'c', // Topic Name
|
||||
'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload,
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
var pks []packets.Packet
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
pks = append(pks, pk)
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
err := <-o
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, io.EOF)
|
||||
require.Equal(t, 2, len(pks))
|
||||
require.Equal(t, []packets.Packet{
|
||||
{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 18,
|
||||
},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello mochi"),
|
||||
},
|
||||
{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 11,
|
||||
},
|
||||
TopicName: "d/e/f",
|
||||
Payload: []byte("yeah"),
|
||||
},
|
||||
}, pks)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&cl.ops.info.MessagesReceived))
|
||||
}
|
||||
|
||||
func TestClientReadDone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
cl.State.done = 1
|
||||
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
o <- cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
require.NoError(t, <-o)
|
||||
}
|
||||
|
||||
func TestClientStop(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(nil)
|
||||
require.Equal(t, nil, cl.State.stopCause.Load())
|
||||
require.Equal(t, time.Now().Unix(), cl.State.disconnected)
|
||||
require.Equal(t, uint32(1), cl.State.done)
|
||||
require.Equal(t, nil, cl.StopCause())
|
||||
}
|
||||
|
||||
func TestClientClosed(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
require.False(t, cl.Closed())
|
||||
cl.Stop(nil)
|
||||
require.True(t, cl.Closed())
|
||||
}
|
||||
|
||||
func TestClientReadFixedHeaderError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
cl.Net.bconn = nil
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrConnectionClosed, err)
|
||||
}
|
||||
|
||||
func TestClientReadReadHandlerErr(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5, // Topic Name - LSB+MSB
|
||||
'd', '/', 'e', '/', 'f', // Topic Name
|
||||
'y', 'e', 'a', 'h', // Payload
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
err := cl.Read(func(cl *Client, pk packets.Packet) error {
|
||||
return errors.New("test")
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadReadPacketOK(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
packets.Publish << 4, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pk)
|
||||
|
||||
require.Equal(t, packets.Packet{
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
Remaining: 11,
|
||||
},
|
||||
TopicName: "d/e/f",
|
||||
Payload: []byte("yeah"),
|
||||
}, pk)
|
||||
}
|
||||
|
||||
func TestClientReadPacket(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
for _, tx := range pkTable {
|
||||
tt := tx // avoid data race
|
||||
t.Run(tt.Desc, func(t *testing.T) {
|
||||
atomic.StoreInt64(&cl.ops.info.PacketsReceived, 0)
|
||||
go func() {
|
||||
r.Write(tt.RawBytes)
|
||||
}()
|
||||
|
||||
fh := new(packets.FixedHeader)
|
||||
err := cl.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.Packet.ProtocolVersion == 5 {
|
||||
cl.Properties.ProtocolVersion = 5
|
||||
} else {
|
||||
cl.Properties.ProtocolVersion = 0
|
||||
}
|
||||
|
||||
pk, err := cl.ReadPacket(fh)
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
require.NotNil(t, pk, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, *tt.Packet, pk, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
if tt.Packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsReceived), pkInfo, tt.Case, tt.Desc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadPacketInvalidTypeError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Conn.Close()
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid packet type")
|
||||
}
|
||||
|
||||
func TestClientWritePacket(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
|
||||
cl.Properties.ProtocolVersion = tt.Packet.ProtocolVersion
|
||||
|
||||
o := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
o <- buf
|
||||
}()
|
||||
|
||||
err := cl.WritePacket(*tt.Packet)
|
||||
require.NoError(t, err, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
cl.Net.Conn.Close()
|
||||
|
||||
require.Equal(t, tt.RawBytes, <-o, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
cl.Stop(errClientStop)
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
|
||||
// The stop cause is either the test error, EOF, or a
|
||||
// closed pipe, depending on which goroutine runs first.
|
||||
err = cl.StopCause()
|
||||
require.True(t,
|
||||
errors.Is(err, errClientStop) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, io.ErrClosedPipe))
|
||||
|
||||
require.Equal(t, int64(len(tt.RawBytes)), atomic.LoadInt64(&cl.ops.info.BytesSent))
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.PacketsSent))
|
||||
if tt.Packet.FixedHeader.Type == packets.Publish {
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&cl.ops.info.MessagesSent))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteClientOversizePacket(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Properties.Props.MaximumPacketSize = 2
|
||||
pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishDropOversize).Packet
|
||||
err := cl.WritePacket(pk)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, packets.ErrPacketTooLarge, err)
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadingError(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
Type: 0,
|
||||
Remaining: 11,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientReadPacketReadUnknown(t *testing.T) {
|
||||
cl, r, _ := newTestClient()
|
||||
defer cl.Stop(errClientStop)
|
||||
go func() {
|
||||
r.Write([]byte{
|
||||
0, 11, // Fixed header
|
||||
0, 5,
|
||||
'd', '/', 'e', '/', 'f',
|
||||
'y', 'e', 'a', 'h',
|
||||
})
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
_, err := cl.ReadPacket(&packets.FixedHeader{
|
||||
Remaining: 1,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteNoConn(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Stop(errClientStop)
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, ErrConnectionClosed, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketWriteError(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.Net.Conn.Close()
|
||||
|
||||
err := cl.WritePacket(*pkTable[1].Packet)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientWritePacketInvalidPacket(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
err := cl.WritePacket(packets.Packet{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
var (
|
||||
pkTable = []packets.TPacketCase{
|
||||
packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311),
|
||||
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedMqtt5),
|
||||
packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession),
|
||||
packets.TPacketData[packets.Publish].Get(packets.TPublishBasic),
|
||||
packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5),
|
||||
packets.TPacketData[packets.Puback].Get(packets.TPuback),
|
||||
packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
|
||||
packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
|
||||
packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
|
||||
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe),
|
||||
packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5),
|
||||
packets.TPacketData[packets.Suback].Get(packets.TSuback),
|
||||
packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5),
|
||||
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribe),
|
||||
packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5),
|
||||
packets.TPacketData[packets.Unsuback].Get(packets.TUnsuback),
|
||||
packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5),
|
||||
packets.TPacketData[packets.Pingreq].Get(packets.TPingreq),
|
||||
packets.TPacketData[packets.Pingresp].Get(packets.TPingresp),
|
||||
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect),
|
||||
packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5),
|
||||
packets.TPacketData[packets.Auth].Get(packets.TAuth),
|
||||
}
|
||||
)
|
||||
52
backend/services/mochi/cmd/auth.json
Normal file
52
backend/services/mochi/cmd/auth.json
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
{
|
||||
"auth": [
|
||||
{
|
||||
"username": "leandro",
|
||||
"password": "leandro",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"username": "steve",
|
||||
"password": "steve",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"username": "root",
|
||||
"password": "root",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"remote": "*",
|
||||
"allow": false
|
||||
},
|
||||
{
|
||||
"remote": "*",
|
||||
"allow": false
|
||||
}
|
||||
],
|
||||
"acl": [
|
||||
{
|
||||
"remote": "*"
|
||||
},
|
||||
{
|
||||
"username": "leandro",
|
||||
"filters": {
|
||||
"oktopus/+/agent/+": 1,
|
||||
"oktopus/+/controller/+": 2
|
||||
}
|
||||
},
|
||||
{
|
||||
"username": "steve",
|
||||
"filters": {
|
||||
"oktopus/+/agent/+": 1,
|
||||
"oktopus/+/controller/+": 2
|
||||
}
|
||||
},
|
||||
{
|
||||
"username": "root",
|
||||
"filters": {
|
||||
"#": 3
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
103
backend/services/mochi/cmd/main.go
Normal file
103
backend/services/mochi/cmd/main.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"github.com/rs/zerolog"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
|
||||
wsAddr := flag.String("ws", "", "network address for Websocket listener")
|
||||
infoAddr := flag.String("info", "", "network address for web info dashboard listener")
|
||||
path := flag.String("path", "", "path to data auth file")
|
||||
flag.Parse()
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(&mqtt.Options{
|
||||
//Capabilities: &mqtt.Capabilities{
|
||||
// ServerKeepAlive: 10000,
|
||||
// ReceiveMaximum: math.MaxUint16,
|
||||
// MaximumMessageExpiryInterval: math.MaxUint32,
|
||||
// MaximumSessionExpiryInterval: math.MaxUint32, // maximum number of seconds to keep disconnected sessions
|
||||
// MaximumClientWritesPending: 65536,
|
||||
// MaximumPacketSize: 0,
|
||||
// MaximumQos: 2,
|
||||
//},
|
||||
})
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
if *path != "" {
|
||||
data, err := os.ReadFile(*path)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
err := server.AddHook(new(auth.AllowHook), nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if *tcpAddr != "" {
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if *wsAddr != "" {
|
||||
ws := listeners.NewWebsocket("ws1", *wsAddr, nil)
|
||||
err := server.AddListener(ws)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if *infoAddr != "" {
|
||||
stats := listeners.NewHTTPStats("stats", *infoAddr, nil, server.Info)
|
||||
err := server.AddListener(stats)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
83
backend/services/mochi/examples/auth/basic/main.go
Normal file
83
backend/services/mochi/examples/auth/basic/main.go
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
authRules := &auth.Ledger{
|
||||
Auth: auth.AuthRules{ // Auth disallows all by default
|
||||
{Username: "peach", Password: "password1", Allow: true},
|
||||
{Username: "melon", Password: "password2", Allow: true},
|
||||
{Remote: "127.0.0.1:*", Allow: true},
|
||||
{Remote: "localhost:*", Allow: true},
|
||||
},
|
||||
ACL: auth.ACLRules{ // ACL allows all by default
|
||||
{Remote: "127.0.0.1:*"}, // local superuser allow all
|
||||
{
|
||||
// user melon can read and write to their own topic
|
||||
Username: "melon", Filters: auth.Filters{
|
||||
"melon/#": auth.ReadWrite,
|
||||
"updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others
|
||||
},
|
||||
},
|
||||
{
|
||||
// Otherwise, no clients have publishing permissions
|
||||
Filters: auth.Filters{
|
||||
"#": auth.ReadOnly,
|
||||
"updates/#": auth.Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// you may also find this useful...
|
||||
// d, _ := authRules.ToYAML()
|
||||
// d, _ := authRules.ToJSON()
|
||||
// fmt.Println(string(d))
|
||||
|
||||
server := mqtt.New(nil)
|
||||
err := server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Ledger: authRules,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
52
backend/services/mochi/examples/auth/encoded/auth.json
Normal file
52
backend/services/mochi/examples/auth/encoded/auth.json
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
{
|
||||
"auth": [
|
||||
{
|
||||
"username": "leandro",
|
||||
"password": "leandro",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"username": "steve",
|
||||
"password": "steve",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"username": "root",
|
||||
"password": "root",
|
||||
"allow": true
|
||||
},
|
||||
{
|
||||
"remote": "*",
|
||||
"allow": false
|
||||
},
|
||||
{
|
||||
"remote": "*",
|
||||
"allow": false
|
||||
}
|
||||
],
|
||||
"acl": [
|
||||
{
|
||||
"remote": "*"
|
||||
},
|
||||
{
|
||||
"username": "leandro",
|
||||
"filters": {
|
||||
"oktopus/+/agent/+": 1,
|
||||
"oktopus/+/controller/+": 2
|
||||
}
|
||||
},
|
||||
{
|
||||
"username": "steve",
|
||||
"filters": {
|
||||
"oktopus/+/agent/+": 1,
|
||||
"oktopus/+/controller/+": 2
|
||||
}
|
||||
},
|
||||
{
|
||||
"username": "root",
|
||||
"filters": {
|
||||
"#": 3
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
21
backend/services/mochi/examples/auth/encoded/auth.yaml
Normal file
21
backend/services/mochi/examples/auth/encoded/auth.yaml
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
auth:
|
||||
- username: peach
|
||||
password: password1
|
||||
allow: true
|
||||
- username: melon
|
||||
password: password2
|
||||
allow: true
|
||||
# - remote: 127.0.0.1:*
|
||||
# allow: true
|
||||
# - remote: localhost:*
|
||||
# allow: true
|
||||
acl:
|
||||
# 0 = deny, 1 = read only, 2 = write only, 3 = read and write
|
||||
- remote: 127.0.0.1:*
|
||||
- username: melon
|
||||
filters:
|
||||
melon/#: 3
|
||||
updates/#: 2
|
||||
- filters:
|
||||
'#': 1
|
||||
updates/#: 0
|
||||
65
backend/services/mochi/examples/auth/encoded/main.go
Normal file
65
backend/services/mochi/examples/auth/encoded/main.go
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// You can also run from top-level server.go folder:
|
||||
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.yaml
|
||||
// go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.json
|
||||
path := flag.String("path", "auth.yaml", "path to data auth file")
|
||||
flag.Parse()
|
||||
|
||||
// Get ledger from yaml file
|
||||
data, err := os.ReadFile(*path)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
server := mqtt.New(nil)
|
||||
err = server.AddHook(new(auth.Hook), &auth.Options{
|
||||
Data: data, // build ledger from byte slice, yaml or json
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
52
backend/services/mochi/examples/benchmark/main.go
Normal file
52
backend/services/mochi/examples/benchmark/main.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener")
|
||||
flag.Parse()
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
tcp := listeners.NewTCP("t1", *tcpAddr, nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
62
backend/services/mochi/examples/debug/main.go
Normal file
62
backend/services/mochi/examples/debug/main.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/debug"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
err := server.AddHook(new(auth.AllowHook), nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddHook(new(debug.Hook), &debug.Options{
|
||||
// ShowPacketData: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
143
backend/services/mochi/examples/hooks/main.go
Normal file
143
backend/services/mochi/examples/hooks/main.go
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddHook(new(ExampleHook), map[string]any{})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start the server
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Demonstration of directly publishing messages to a topic via the
|
||||
// `server.Publish` method. Subscribe to `direct/publish` using your
|
||||
// MQTT client to see the messages.
|
||||
go func() {
|
||||
cl := server.NewClient(nil, "local", "inline", true)
|
||||
for range time.Tick(time.Second * 1) {
|
||||
err := server.InjectPacket(cl, packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Type: packets.Publish,
|
||||
},
|
||||
TopicName: "direct/publish",
|
||||
Payload: []byte("injected scheduled message"),
|
||||
})
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.InjectPacket")
|
||||
}
|
||||
server.Log.Info().Msgf("main.go injected packet to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
// There is also a shorthand convenience function, Publish, for easily sending
|
||||
// publish packets if you are not concerned with creating your own packets.
|
||||
go func() {
|
||||
for range time.Tick(time.Second * 5) {
|
||||
err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0)
|
||||
if err != nil {
|
||||
server.Log.Error().Err(err).Msg("server.Publish")
|
||||
}
|
||||
server.Log.Info().Msgf("main.go issued direct message to direct/publish")
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
||||
type ExampleHook struct {
|
||||
mqtt.HookBase
|
||||
}
|
||||
|
||||
func (h *ExampleHook) ID() string {
|
||||
return "events-example"
|
||||
}
|
||||
|
||||
func (h *ExampleHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnect,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnPublished,
|
||||
mqtt.OnPublish,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
func (h *ExampleHook) Init(config any) error {
|
||||
h.Log.Info().Msg("initialised")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Msgf("client connected")
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
|
||||
h.Log.Info().Str("client", cl.ID).Bool("expire", expire).Err(err).Msg("client disconnected")
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msgf("subscribed qos=%v", reasonCodes)
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Interface("filters", pk.Filters).Msg("unsubscribed")
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("received from client")
|
||||
|
||||
pkx := pk
|
||||
if string(pk.Payload) == "hello" {
|
||||
pkx.Payload = []byte("hello world")
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pkx.Payload)).Msg("received modified packet from client")
|
||||
}
|
||||
|
||||
return pkx, nil
|
||||
}
|
||||
|
||||
func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Info().Str("client", cl.ID).Str("payload", string(pk.Payload)).Msg("published to client")
|
||||
}
|
||||
74
backend/services/mochi/examples/paho.testing/main.go
Normal file
74
backend/services/mochi/examples/paho.testing/main.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
server.Options.Capabilities.ServerKeepAlive = 60
|
||||
server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true
|
||||
server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true
|
||||
|
||||
_ = server.AddHook(new(pahoAuthHook), nil)
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err := server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
|
||||
type pahoAuthHook struct {
|
||||
mqtt.HookBase
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) ID() string {
|
||||
return "allow-all-auth"
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
return topic != "test/nosubscribe"
|
||||
}
|
||||
59
backend/services/mochi/examples/persistence/badger/main.go
Normal file
59
backend/services/mochi/examples/persistence/badger/main.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/badger"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
badgerPath := ".badger"
|
||||
defer os.RemoveAll(badgerPath) // remove the example badger files at the end
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
err := server.AddHook(new(badger.Hook), &badger.Options{
|
||||
Path: badgerPath,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
|
||||
}
|
||||
60
backend/services/mochi/examples/persistence/bolt/main.go
Normal file
60
backend/services/mochi/examples/persistence/bolt/main.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/bolt"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
err := server.AddHook(new(bolt.Hook), &bolt.Options{
|
||||
Path: "bolt.db",
|
||||
Options: &bbolt.Options{
|
||||
Timeout: 500 * time.Millisecond,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
65
backend/services/mochi/examples/persistence/redis/main.go
Normal file
65
backend/services/mochi/examples/persistence/redis/main.go
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage/redis"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
rv8 "github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
l := server.Log.Level(zerolog.DebugLevel)
|
||||
server.Log = &l
|
||||
|
||||
err := server.AddHook(new(redis.Hook), &redis.Options{
|
||||
Options: &rv8.Options{
|
||||
Addr: "localhost:6379", // default redis address
|
||||
Password: "", // your password
|
||||
DB: 0, // your redis db
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
|
||||
}
|
||||
58
backend/services/mochi/examples/tcp/main.go
Normal file
58
backend/services/mochi/examples/tcp/main.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// An example of configuring various server options...
|
||||
options := &mqtt.Options{
|
||||
// InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages
|
||||
}
|
||||
|
||||
server := mqtt.New(options)
|
||||
|
||||
// For security reasons, the default implementation disallows all connections.
|
||||
// If you want to allow all connections, you must specifically allow it.
|
||||
err := server.AddHook(new(auth.AllowHook), nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", nil)
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
117
backend/services/mochi/examples/tls/main.go
Normal file
117
backend/services/mochi/examples/tls/main.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
var (
|
||||
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
|
||||
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
|
||||
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
|
||||
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
|
||||
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
|
||||
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
|
||||
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
|
||||
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
|
||||
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
|
||||
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
|
||||
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
|
||||
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
|
||||
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
|
||||
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
|
||||
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
|
||||
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
|
||||
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
|
||||
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
|
||||
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
|
||||
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
// Optionally, if you want clients to authenticate only with certs issued by your CA,
|
||||
// you might want to use something like this:
|
||||
// certPool := x509.NewCertPool()
|
||||
// _ = certPool.AppendCertsFromPEM(caCertPem)
|
||||
// tlsConfig := &tls.Config{
|
||||
// ClientCAs: certPool,
|
||||
// ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// }
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
tcp := listeners.NewTCP("t1", ":1883", &listeners.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
err = server.AddListener(tcp)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ws := listeners.NewWebsocket("ws1", ":1882", &listeners.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
err = server.AddListener(ws)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stats := listeners.NewHTTPStats("stats", ":8080", &listeners.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
}, nil)
|
||||
err = server.AddListener(stats)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
47
backend/services/mochi/examples/websocket/main.go
Normal file
47
backend/services/mochi/examples/websocket/main.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/auth"
|
||||
"github.com/mochi-co/mqtt/v2/listeners"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sigs := make(chan os.Signal, 1)
|
||||
done := make(chan bool, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigs
|
||||
done <- true
|
||||
}()
|
||||
|
||||
server := mqtt.New(nil)
|
||||
_ = server.AddHook(new(auth.AllowHook), nil)
|
||||
|
||||
ws := listeners.NewWebsocket("ws1", ":1882", nil)
|
||||
err := server.AddListener(ws)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
server.Log.Warn().Msg("caught signal, stopping...")
|
||||
server.Close()
|
||||
server.Log.Info().Msg("main.go finished")
|
||||
}
|
||||
40
backend/services/mochi/go.mod
Normal file
40
backend/services/mochi/go.mod
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
module github.com/mochi-co/mqtt/v2
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.23.0
|
||||
github.com/asdine/storm v2.1.2+incompatible
|
||||
github.com/asdine/storm/v3 v3.2.1
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.3.5
|
||||
github.com/rs/xid v1.4.0
|
||||
github.com/rs/zerolog v1.28.0
|
||||
github.com/stretchr/testify v1.7.1
|
||||
github.com/timshannon/badgerhold v1.0.0
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgraph-io/badger v1.6.0 // indirect
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.0 // indirect
|
||||
github.com/golang/protobuf v1.5.0 // indirect
|
||||
github.com/golang/snappy v0.0.3 // indirect
|
||||
github.com/mattn/go-colorable v0.1.12 // indirect
|
||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
|
||||
golang.org/x/net v0.7.0 // indirect
|
||||
golang.org/x/sys v0.5.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/protobuf v1.28.1 // indirect
|
||||
)
|
||||
143
backend/services/mochi/go.sum
Normal file
143
backend/services/mochi/go.sum
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
|
||||
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M=
|
||||
github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM=
|
||||
github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
|
||||
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863 h1:BRrxwOZBolJN4gIwvZMJY1tzqBvQgpaZiQRuIDD40jM=
|
||||
github.com/Sereal/Sereal v0.0.0-20190618215532-0b8ac451a863/go.mod h1:D0JMgToj/WdxCgd30Kc1UcA9E+WdZoJqeVOuYW7iTBM=
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||
github.com/alicebob/miniredis/v2 v2.23.0 h1:+lwAJYjvvdIVg6doFHuotFjueJ/7KY10xo/vm3X3Scw=
|
||||
github.com/alicebob/miniredis/v2 v2.23.0/go.mod h1:XNqvJdQJv5mSuVMc0ynneafpnL/zv52acZ6kqeS0t88=
|
||||
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
|
||||
github.com/asdine/storm v2.1.2+incompatible h1:dczuIkyqwY2LrtXPz8ixMrU/OFgZp71kbKTHGrXYt/Q=
|
||||
github.com/asdine/storm v2.1.2+incompatible/go.mod h1:RarYDc9hq1UPLImuiXK3BIWPJLdIygvV3PsInK0FbVQ=
|
||||
github.com/asdine/storm/v3 v3.2.1 h1:I5AqhkPK6nBZ/qJXySdI7ot5BlXSZ7qvDY1zAn5ZJac=
|
||||
github.com/asdine/storm/v3 v3.2.1/go.mod h1:LEpXwGt4pIqrE/XcTvCnZHT5MgZCV6Ub9q7yQzOFWr0=
|
||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
||||
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
|
||||
github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk=
|
||||
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
|
||||
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgraph-io/badger v1.6.0 h1:DshxFxZWXUcO0xX476VJC07Xsr6ZCBVRHKZ93Oh7Evo=
|
||||
github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4=
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA=
|
||||
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg=
|
||||
github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
|
||||
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
|
||||
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
|
||||
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
|
||||
github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
|
||||
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
|
||||
github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU=
|
||||
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
|
||||
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
|
||||
github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/timshannon/badgerhold v1.0.0 h1:LtqnDRVP7294FWRiZCIfQa6Tt0bGmlzbO8c364QC2Y8=
|
||||
github.com/timshannon/badgerhold v1.0.0/go.mod h1:Vv2Jj0PAfzqViEpGvJzLP8PY07x1iXLgKRuLY7bqPOE=
|
||||
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI=
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk=
|
||||
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
|
||||
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw=
|
||||
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
|
||||
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0=
|
||||
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
|
||||
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
794
backend/services/mochi/hooks.go
Normal file
794
backend/services/mochi/hooks.go
Normal file
|
|
@ -0,0 +1,794 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co, thedevop
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
SetOptions byte = iota
|
||||
OnSysInfoTick
|
||||
OnStarted
|
||||
OnStopped
|
||||
OnConnectAuthenticate
|
||||
OnACLCheck
|
||||
OnConnect
|
||||
OnSessionEstablished
|
||||
OnDisconnect
|
||||
OnAuthPacket
|
||||
OnPacketRead
|
||||
OnPacketEncode
|
||||
OnPacketSent
|
||||
OnPacketProcessed
|
||||
OnSubscribe
|
||||
OnSubscribed
|
||||
OnSelectSubscribers
|
||||
OnUnsubscribe
|
||||
OnUnsubscribed
|
||||
OnPublish
|
||||
OnPublished
|
||||
OnPublishDropped
|
||||
OnRetainMessage
|
||||
OnQosPublish
|
||||
OnQosComplete
|
||||
OnQosDropped
|
||||
OnWill
|
||||
OnWillSent
|
||||
OnClientExpired
|
||||
OnRetainedExpired
|
||||
StoredClients
|
||||
StoredSubscriptions
|
||||
StoredInflightMessages
|
||||
StoredRetainedMessages
|
||||
StoredSysInfo
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidConfigType indicates a different Type of config value was expected to what was received.
|
||||
ErrInvalidConfigType = errors.New("invalid config type provided")
|
||||
)
|
||||
|
||||
// Hook provides an interface of handlers for different events which occur
|
||||
// during the lifecycle of the broker.
|
||||
type Hook interface {
|
||||
ID() string
|
||||
Provides(b byte) bool
|
||||
Init(config any) error
|
||||
Stop() error
|
||||
SetOpts(l *zerolog.Logger, o *HookOptions)
|
||||
OnStarted()
|
||||
OnStopped()
|
||||
OnConnectAuthenticate(cl *Client, pk packets.Packet) bool
|
||||
OnACLCheck(cl *Client, topic string, write bool) bool
|
||||
OnSysInfoTick(*system.Info)
|
||||
OnConnect(cl *Client, pk packets.Packet)
|
||||
OnSessionEstablished(cl *Client, pk packets.Packet)
|
||||
OnDisconnect(cl *Client, err error, expire bool)
|
||||
OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) // triggers when a new packet is received by a client, but before packet validation
|
||||
OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet // modify a packet before it is byte-encoded and written to the client
|
||||
OnPacketSent(cl *Client, pk packets.Packet, b []byte) // triggers when packet bytes have been written to the client
|
||||
OnPacketProcessed(cl *Client, pk packets.Packet, err error) // triggers after a packet from the client been processed (handled)
|
||||
OnSubscribe(cl *Client, pk packets.Packet) packets.Packet
|
||||
OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte)
|
||||
OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers
|
||||
OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet
|
||||
OnUnsubscribed(cl *Client, pk packets.Packet)
|
||||
OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error)
|
||||
OnPublished(cl *Client, pk packets.Packet)
|
||||
OnPublishDropped(cl *Client, pk packets.Packet)
|
||||
OnRetainMessage(cl *Client, pk packets.Packet, r int64)
|
||||
OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int)
|
||||
OnQosComplete(cl *Client, pk packets.Packet)
|
||||
OnQosDropped(cl *Client, pk packets.Packet)
|
||||
OnWill(cl *Client, will Will) (Will, error)
|
||||
OnWillSent(cl *Client, pk packets.Packet)
|
||||
OnClientExpired(cl *Client)
|
||||
OnRetainedExpired(filter string)
|
||||
StoredClients() ([]storage.Client, error)
|
||||
StoredSubscriptions() ([]storage.Subscription, error)
|
||||
StoredInflightMessages() ([]storage.Message, error)
|
||||
StoredRetainedMessages() ([]storage.Message, error)
|
||||
StoredSysInfo() (storage.SystemInfo, error)
|
||||
}
|
||||
|
||||
// HookOptions contains values which are inherited from the server on initialisation.
|
||||
type HookOptions struct {
|
||||
Capabilities *Capabilities
|
||||
}
|
||||
|
||||
// Hooks is a slice of Hook interfaces to be called in sequence.
|
||||
type Hooks struct {
|
||||
Log *zerolog.Logger // a logger for the hook (from the server)
|
||||
internal atomic.Value // a slice of []Hook
|
||||
wg sync.WaitGroup // a waitgroup for syncing hook shutdown
|
||||
qty int64 // the number of hooks in use
|
||||
sync.Mutex // a mutex for locking when adding hooks
|
||||
}
|
||||
|
||||
// Len returns the number of hooks added.
|
||||
func (h *Hooks) Len() int64 {
|
||||
return atomic.LoadInt64(&h.qty)
|
||||
}
|
||||
|
||||
// Provides returns true if any one hook provides any of the requested hook methods.
|
||||
func (h *Hooks) Provides(b ...byte) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
for _, hb := range b {
|
||||
if hook.Provides(hb) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Add adds and initializes a new hook.
|
||||
func (h *Hooks) Add(hook Hook, config any) error {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
|
||||
err := hook.Init(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initialising %s hook: %w", hook.ID(), err)
|
||||
}
|
||||
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
i = []Hook{}
|
||||
}
|
||||
|
||||
i = append(i, hook)
|
||||
h.internal.Store(i)
|
||||
atomic.AddInt64(&h.qty, 1)
|
||||
h.wg.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAll returns a slice of all the hooks.
|
||||
func (h *Hooks) GetAll() []Hook {
|
||||
i, ok := h.internal.Load().([]Hook)
|
||||
if !ok {
|
||||
return []Hook{}
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// Stop indicates all attached hooks to gracefully end.
|
||||
func (h *Hooks) Stop() {
|
||||
go func() {
|
||||
for _, hook := range h.GetAll() {
|
||||
h.Log.Info().Str("hook", hook.ID()).Msg("stopping hook")
|
||||
if err := hook.Stop(); err != nil {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Msg("problem stopping hook")
|
||||
}
|
||||
|
||||
h.wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
h.wg.Wait()
|
||||
}
|
||||
|
||||
// OnSysInfoTick is called when the $SYS topic values are published out.
|
||||
func (h *Hooks) OnSysInfoTick(sys *system.Info) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSysInfoTick) {
|
||||
hook.OnSysInfoTick(sys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnStarted is called when the server has successfully started.
|
||||
func (h *Hooks) OnStarted() {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStarted) {
|
||||
hook.OnStarted()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnStopped is called when the server has successfully stopped.
|
||||
func (h *Hooks) OnStopped() {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnStopped) {
|
||||
hook.OnStopped()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *Hooks) OnConnect(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnect) {
|
||||
hook.OnConnect(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *Hooks) OnSessionEstablished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSessionEstablished) {
|
||||
hook.OnSessionEstablished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *Hooks) OnDisconnect(cl *Client, err error, expire bool) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnDisconnect) {
|
||||
hook.OnDisconnect(cl, err, expire)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a packet is received from a client.
|
||||
func (h *Hooks) OnPacketRead(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketRead) {
|
||||
npk, err := hook.OnPacketRead(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("packet rejected")
|
||||
return pk, err
|
||||
} else if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnAuthPacket is called when an auth packet is received. It is intended to allow developers
|
||||
// to create their own auth packet handling mechanisms.
|
||||
func (h *Hooks) OnAuthPacket(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnAuthPacket) {
|
||||
npk, err := hook.OnAuthPacket(cl, pkx)
|
||||
if err != nil {
|
||||
return pk, err
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnPacketEncode is called immediately before a packet is encoded to be sent to a client.
|
||||
func (h *Hooks) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketEncode) {
|
||||
pk = hook.OnPacketEncode(cl, pk)
|
||||
}
|
||||
}
|
||||
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnPacketProcessed is called when a packet has been received and successfully handled by the broker.
|
||||
func (h *Hooks) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketProcessed) {
|
||||
hook.OnPacketProcessed(cl, pk, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet has been sent to a client. It takes a bytes parameter
|
||||
// containing the bytes sent.
|
||||
func (h *Hooks) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPacketSent) {
|
||||
hook.OnPacketSent(cl, pk, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribe is called when a client subscribes to one or more filters. This method
|
||||
// differs from OnSubscribed in that it allows you to modify the subscription values
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribe) {
|
||||
pk = hook.OnSubscribe(cl, pk)
|
||||
}
|
||||
}
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *Hooks) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSubscribed) {
|
||||
hook.OnSubscribed(cl, pk, reasonCodes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnSelectSubscribers is called when subscribers have been collected for a topic, but before
|
||||
// shared subscription subscribers have been selected. This hook can be used to programmatically
|
||||
// remove or add clients to a publish to subscribers process, or to select the subscriber for a shared
|
||||
// group in a custom manner (such as based on client id, ip, etc).
|
||||
func (h *Hooks) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnSelectSubscribers) {
|
||||
subs = hook.OnSelectSubscribers(subs, pk)
|
||||
}
|
||||
}
|
||||
return subs
|
||||
}
|
||||
|
||||
// OnUnsubscribe is called when a client unsubscribes from one or more filters. This method
|
||||
// differs from OnUnsubscribed in that it allows you to modify the unsubscription values
|
||||
// before the packet is processed. The return values of the hook methods are passed-through
|
||||
// in the order the hooks were attached.
|
||||
func (h *Hooks) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribe) {
|
||||
pk = hook.OnUnsubscribe(cl, pk)
|
||||
}
|
||||
}
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *Hooks) OnUnsubscribed(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnUnsubscribed) {
|
||||
hook.OnUnsubscribed(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublish is called when a client publishes a message. This method differs from OnPublished
|
||||
// in that it allows you to modify you to modify the incoming packet before it is processed.
|
||||
// The return values of the hook methods are passed-through in the order the hooks were attached.
|
||||
func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, err error) {
|
||||
pkx = pk
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublish) {
|
||||
npk, err := hook.OnPublish(cl, pkx)
|
||||
if err != nil && errors.Is(err, packets.ErrRejectPacket) {
|
||||
h.Log.Debug().Err(err).Str("hook", hook.ID()).Interface("packet", pkx).Msg("publish packet rejected")
|
||||
return pk, err
|
||||
} else if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pkx = npk
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *Hooks) OnPublished(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublished) {
|
||||
hook.OnPublished(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnPublishDropped is called when a message to a client was dropped instead of delivered
|
||||
// such as when a client is too slow to respond.
|
||||
func (h *Hooks) OnPublishDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnPublishDropped) {
|
||||
hook.OnPublishDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *Hooks) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainMessage) {
|
||||
hook.OnRetainMessage(cl, pk, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos >= 1 is issued to a subscriber.
|
||||
// In other words, this method is called when a new inflight message is created or resent.
|
||||
// It is typically used to store a new inflight message.
|
||||
func (h *Hooks) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosPublish) {
|
||||
hook.OnQosPublish(cl, pk, sent, resends)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
// In other words, when an inflight message is resolved.
|
||||
// It is typically used to delete an inflight message from a store.
|
||||
func (h *Hooks) OnQosComplete(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosComplete) {
|
||||
hook.OnQosComplete(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires. In other words, when
|
||||
// an inflight message expires or is abandoned. It is typically used to delete an
|
||||
// inflight message from a store.
|
||||
func (h *Hooks) OnQosDropped(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnQosDropped) {
|
||||
hook.OnQosDropped(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message. This method
|
||||
// differs from OnWillSent in that it allows you to modify the LWT message before it is
|
||||
// published. The return values of the hook methods are passed-through in the order
|
||||
// the hooks were attached.
|
||||
func (h *Hooks) OnWill(cl *Client, will Will) Will {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWill) {
|
||||
mlwt, err := hook.OnWill(cl, will)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Interface("will", will).Msg("parse will error")
|
||||
continue
|
||||
}
|
||||
will = mlwt
|
||||
}
|
||||
}
|
||||
|
||||
return will
|
||||
}
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *Hooks) OnWillSent(cl *Client, pk packets.Packet) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnWillSent) {
|
||||
hook.OnWillSent(cl, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired is called when a client session has expired and should be deleted.
|
||||
func (h *Hooks) OnClientExpired(cl *Client) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnClientExpired) {
|
||||
hook.OnClientExpired(cl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired is called when a retained message has expired and should be deleted.
|
||||
func (h *Hooks) OnRetainedExpired(filter string) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnRetainedExpired) {
|
||||
hook.OnRetainedExpired(filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all clients, e.g. from a persistent store, is used to
|
||||
// populate the server clients list before start.
|
||||
func (h *Hooks) StoredClients() (v []storage.Client, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredClients) {
|
||||
v, err := hook.StoredClients()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load clients")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all subcriptions, e.g. from a persistent store, and is
|
||||
// used to populate the server subscriptions list before start.
|
||||
func (h *Hooks) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSubscriptions) {
|
||||
v, err := hook.StoredSubscriptions()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load subscriptions")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all inflight messages, e.g. from a persistent store,
|
||||
// and is used to populate the restored clients with inflight messages before start.
|
||||
func (h *Hooks) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredInflightMessages) {
|
||||
v, err := hook.StoredInflightMessages()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load inflight messages")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all retained messages, e.g. from a persistent store,
|
||||
// and is used to populate the server topics with retained messages before start.
|
||||
func (h *Hooks) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredRetainedMessages) {
|
||||
v, err := hook.StoredRetainedMessages()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load retained messages")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if len(v) > 0 {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *Hooks) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(StoredSysInfo) {
|
||||
v, err := hook.StoredSysInfo()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("hook", hook.ID()).Msg("failed to load $SYS info")
|
||||
return v, err
|
||||
}
|
||||
|
||||
if v.Version != "" {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
|
||||
// An implementation of this method MUST be used to allow or deny access to the
|
||||
// server (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check connecting users against an existing user database.
|
||||
func (h *Hooks) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnConnectAuthenticate) {
|
||||
if ok := hook.OnConnectAuthenticate(cl, pk); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck is called when a user attempts to publish or subscribe to a topic filter.
|
||||
// An implementation of this method MUST be used to allow or deny access to the
|
||||
// (see hooks/auth/allow_all or basic). It can be used in custom hooks to
|
||||
// check publishing and subscribing users against an existing permissions or roles database.
|
||||
func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
for _, hook := range h.GetAll() {
|
||||
if hook.Provides(OnACLCheck) {
|
||||
if ok := hook.OnACLCheck(cl, topic, write); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HookBase provides a set of default methods for each hook. It should be embedded in
|
||||
// all hooks.
|
||||
type HookBase struct {
|
||||
Hook
|
||||
Log *zerolog.Logger
|
||||
Opts *HookOptions
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *HookBase) ID() string {
|
||||
return "base"
|
||||
}
|
||||
|
||||
// Provides indicates which methods a hook provides. The default is none - this method
|
||||
// should be overridden by the embedding hook.
|
||||
func (h *HookBase) Provides(b byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Init performs any pre-start initializations for the hook, such as connecting to databases
|
||||
// or opening files.
|
||||
func (h *HookBase) Init(config any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetOpts is called by the server to propagate internal values and generally should
|
||||
// not be called manually.
|
||||
func (h *HookBase) SetOpts(l *zerolog.Logger, opts *HookOptions) {
|
||||
h.Log = l
|
||||
h.Opts = opts
|
||||
}
|
||||
|
||||
// Stop is called to gracefully shutdown the hook.
|
||||
func (h *HookBase) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStarted is called when the server starts.
|
||||
func (h *HookBase) OnStarted() {}
|
||||
|
||||
// OnStopped is called when the server stops.
|
||||
func (h *HookBase) OnStopped() {}
|
||||
|
||||
// OnSysInfoTick is called when the server publishes system info.
|
||||
func (h *HookBase) OnSysInfoTick(*system.Info) {}
|
||||
|
||||
// OnConnectAuthenticate is called when a user attempts to authenticate with the server.
|
||||
func (h *HookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck is called when a user attempts to subscribe or publish to a topic.
|
||||
func (h *HookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnConnect is called when a new client connects.
|
||||
func (h *HookBase) OnConnect(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnSessionEstablished is called when a new client establishes a session (after OnConnect).
|
||||
func (h *HookBase) OnSessionEstablished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnDisconnect is called when a client is disconnected for any reason.
|
||||
func (h *HookBase) OnDisconnect(cl *Client, err error, expire bool) {}
|
||||
|
||||
// OnAuthPacket is called when an auth packet is received from the client.
|
||||
func (h *HookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a packet is received.
|
||||
func (h *HookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketEncode is called before a packet is byte-encoded and written to the client.
|
||||
func (h *HookBase) OnPacketEncode(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnPacketSent is called immediately after a packet is written to a client.
|
||||
func (h *HookBase) OnPacketSent(cl *Client, pk packets.Packet, b []byte) {}
|
||||
|
||||
// OnPacketProcessed is called immediately after a packet from a client is processed.
|
||||
func (h *HookBase) OnPacketProcessed(cl *Client, pk packets.Packet, err error) {}
|
||||
|
||||
// OnSubscribe is called when a client subscribes to one or more filters.
|
||||
func (h *HookBase) OnSubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnSubscribed is called when a client subscribes to one or more filters.
|
||||
func (h *HookBase) OnSubscribed(cl *Client, pk packets.Packet, reasonCodes []byte) {}
|
||||
|
||||
// OnSelectSubscribers is called when selecting subscribers to receive a message.
|
||||
func (h *HookBase) OnSelectSubscribers(subs *Subscribers, pk packets.Packet) *Subscribers {
|
||||
return subs
|
||||
}
|
||||
|
||||
// OnUnsubscribe is called when a client unsubscribes from one or more filters.
|
||||
func (h *HookBase) OnUnsubscribe(cl *Client, pk packets.Packet) packets.Packet {
|
||||
return pk
|
||||
}
|
||||
|
||||
// OnUnsubscribed is called when a client unsubscribes from one or more filters.
|
||||
func (h *HookBase) OnUnsubscribed(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPublish is called when a client publishes a message.
|
||||
func (h *HookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPublished is called when a client has published a message to subscribers.
|
||||
func (h *HookBase) OnPublished(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnPublishDropped is called when a message to a client is dropped instead of being delivered.
|
||||
func (h *HookBase) OnPublishDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnRetainMessage is called then a published message is retained.
|
||||
func (h *HookBase) OnRetainMessage(cl *Client, pk packets.Packet, r int64) {}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos > 1 is issued to a subscriber.
|
||||
func (h *HookBase) OnQosPublish(cl *Client, pk packets.Packet, sent int64, resends int) {}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
func (h *HookBase) OnQosComplete(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *HookBase) OnQosDropped(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnWill is called when a client disconnects and publishes an LWT message.
|
||||
func (h *HookBase) OnWill(cl *Client, will Will) (Will, error) {
|
||||
return will, nil
|
||||
}
|
||||
|
||||
// OnWillSent is called when an LWT message has been issued from a disconnecting client.
|
||||
func (h *HookBase) OnWillSent(cl *Client, pk packets.Packet) {}
|
||||
|
||||
// OnClientExpired is called when a client session has expired.
|
||||
func (h *HookBase) OnClientExpired(cl *Client) {}
|
||||
|
||||
// OnRetainedExpired is called when a retained message for a topic has expired.
|
||||
func (h *HookBase) OnRetainedExpired(topic string) {}
|
||||
|
||||
// StoredClients returns all clients from a store.
|
||||
func (h *HookBase) StoredClients() (v []storage.Client, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all subcriptions from a store.
|
||||
func (h *HookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all inflight messages from a store.
|
||||
func (h *HookBase) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all retained messages from a store.
|
||||
func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// StoredSysInfo returns a set of system info values.
|
||||
func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
return
|
||||
}
|
||||
41
backend/services/mochi/hooks/auth/allow_all.go
Normal file
41
backend/services/mochi/hooks/auth/allow_all.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
// AllowHook is an authentication hook which allows connection access
|
||||
// for all users and read and write access to all topics.
|
||||
type AllowHook struct {
|
||||
mqtt.HookBase
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *AllowHook) ID() string {
|
||||
return "allow-all-auth"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *AllowHook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate returns true/allowed for all requests.
|
||||
func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// OnACLCheck returns true/allowed for all checks.
|
||||
func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
35
backend/services/mochi/hooks/auth/allow_all_test.go
Normal file
35
backend/services/mochi/hooks/auth/allow_all_test.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAllowAllID(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.Equal(t, "allow-all-auth", h.ID())
|
||||
}
|
||||
|
||||
func TestAllowAllProvides(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
require.False(t, h.Provides(mqtt.OnPublished))
|
||||
}
|
||||
|
||||
func TestAllowAllOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{}))
|
||||
}
|
||||
|
||||
func TestAllowAllOnACLCheck(t *testing.T) {
|
||||
h := new(AllowHook)
|
||||
require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true))
|
||||
}
|
||||
107
backend/services/mochi/hooks/auth/auth.go
Normal file
107
backend/services/mochi/hooks/auth/auth.go
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
// Options contains the configuration/rules data for the auth ledger.
|
||||
type Options struct {
|
||||
Data []byte
|
||||
Ledger *Ledger
|
||||
}
|
||||
|
||||
// Hook is an authentication hook which implements an auth ledger.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options
|
||||
ledger *Ledger
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "auth-ledger"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnConnectAuthenticate,
|
||||
mqtt.OnACLCheck,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init configures the hook with the auth ledger to be used for checking.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
|
||||
var err error
|
||||
if h.config.Ledger != nil {
|
||||
h.ledger = h.config.Ledger
|
||||
} else if len(h.config.Data) > 0 {
|
||||
h.ledger = new(Ledger)
|
||||
err = h.ledger.Unmarshal(h.config.Data)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.ledger == nil {
|
||||
h.ledger = &Ledger{
|
||||
Auth: AuthRules{},
|
||||
ACL: ACLRules{},
|
||||
}
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Int("authentication", len(h.ledger.Auth)).
|
||||
Int("acl", len(h.ledger.ACL)).
|
||||
Msg("loaded auth rules")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnConnectAuthenticate returns true if the connecting client has rules which provide access
|
||||
// in the auth ledger.
|
||||
func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
|
||||
if _, ok := h.ledger.AuthOk(cl, pk); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Str("username", string(pk.Connect.Username)).
|
||||
Str("remote", cl.Net.Remote).
|
||||
Msg("client failed authentication check")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// OnACLCheck returns true if the connecting client has matching read or write access to subscribe
|
||||
// or publish to a given topic.
|
||||
func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool {
|
||||
if _, ok := h.ledger.ACLOk(cl, topic, write); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
h.Log.Debug().
|
||||
Str("client", cl.ID).
|
||||
Str("username", string(cl.Properties.Username)).
|
||||
Str("topic", topic).
|
||||
Msg("client failed allowed ACL check")
|
||||
|
||||
return false
|
||||
}
|
||||
213
backend/services/mochi/hooks/auth/auth_test.go
Normal file
213
backend/services/mochi/hooks/auth/auth_test.go
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
// func teardown(t *testing.T, path string, h *Hook) {
|
||||
// h.Stop()
|
||||
// }
|
||||
|
||||
func TestBasicID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "auth-ledger", h.ID())
|
||||
}
|
||||
|
||||
func TestBasicProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.True(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
require.False(t, h.Provides(mqtt.OnPublish))
|
||||
}
|
||||
|
||||
func TestBasicInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBasicInitDefaultConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerPointer(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
ln := &Ledger{
|
||||
Auth: []AuthRule{
|
||||
{
|
||||
Remote: "127.0.0.1",
|
||||
Allow: true,
|
||||
},
|
||||
},
|
||||
ACL: []ACLRule{
|
||||
{
|
||||
Remote: "127.0.0.1",
|
||||
Filters: Filters{
|
||||
"#": ReadWrite,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := h.Init(&Options{
|
||||
Ledger: ln,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Same(t, ln, h.ledger)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerJSON(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: ledgerJSON,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
|
||||
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerYAML(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: ledgerYAML,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username)
|
||||
require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client)
|
||||
}
|
||||
|
||||
func TestBasicInitWithLedgerBadDAta(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
require.Nil(t, h.ledger)
|
||||
err := h.Init(&Options{
|
||||
Data: []byte("fdsfdsafasd"),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
ln.ACL = checkLedger.ACL
|
||||
err := h.Init(
|
||||
&Options{
|
||||
Ledger: ln,
|
||||
},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
))
|
||||
|
||||
require.False(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
))
|
||||
|
||||
require.False(t, h.OnConnectAuthenticate(
|
||||
&mqtt.Client{},
|
||||
packets.Packet{},
|
||||
))
|
||||
}
|
||||
|
||||
func TestOnACL(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
ln := new(Ledger)
|
||||
ln.Auth = checkLedger.Auth
|
||||
ln.ACL = checkLedger.ACL
|
||||
err := h.Init(
|
||||
&Options{
|
||||
Ledger: ln,
|
||||
},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"mochi/info",
|
||||
true,
|
||||
))
|
||||
|
||||
require.False(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"d/j/f",
|
||||
true,
|
||||
))
|
||||
|
||||
require.True(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"readonly",
|
||||
false,
|
||||
))
|
||||
|
||||
require.False(t, h.OnACLCheck(
|
||||
&mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
"readonly",
|
||||
true,
|
||||
))
|
||||
}
|
||||
231
backend/services/mochi/hooks/auth/ledger.go
Normal file
231
backend/services/mochi/hooks/auth/ledger.go
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
Deny Access = iota // user cannot access the topic
|
||||
ReadOnly // user can only subscribe to the topic
|
||||
WriteOnly // user can only publish to the topic
|
||||
ReadWrite // user can both publish and subscribe to the topic
|
||||
)
|
||||
|
||||
// Access determines the read/write privileges for an ACL rule.
|
||||
type Access byte
|
||||
|
||||
// Users contains a map of access rules for specific users, keyed on username.
|
||||
type Users map[string]UserRule
|
||||
|
||||
// UserRule defines a set of access rules for a specific user.
|
||||
type UserRule struct {
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
|
||||
ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired
|
||||
Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user
|
||||
}
|
||||
|
||||
// AuthRules defines generic access rules applicable to all users.
|
||||
type AuthRules []AuthRule
|
||||
|
||||
type AuthRule struct {
|
||||
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
|
||||
Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user
|
||||
Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users
|
||||
}
|
||||
|
||||
// ACLRules defines generic topic or filter access rules applicable to all users.
|
||||
type ACLRules []ACLRule
|
||||
|
||||
// ACLRule defines access rules for a specific topic or filter.
|
||||
type ACLRule struct {
|
||||
Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client
|
||||
Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user
|
||||
Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or
|
||||
Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match
|
||||
}
|
||||
|
||||
// Filters is a map of Access rules keyed on filter.
|
||||
type Filters map[RString]Access
|
||||
|
||||
// RString is a rule value string.
|
||||
type RString string
|
||||
|
||||
// Matches returns true if the rule matches a given string.
|
||||
func (r RString) Matches(a string) bool {
|
||||
rr := string(r)
|
||||
if r == "" || r == "*" || a == rr {
|
||||
return true
|
||||
}
|
||||
|
||||
i := strings.Index(rr, "*")
|
||||
if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// FilterMatches returns true if a filter matches a topic rule.
|
||||
func (f RString) FilterMatches(a string) bool {
|
||||
_, ok := MatchTopic(string(f), a)
|
||||
return ok
|
||||
}
|
||||
|
||||
// MatchTopic checks if a given topic matches a filter, accounting for filter
|
||||
// wildcards. Eg. filter /a/b/+/c == topic a/b/d/c.
|
||||
func MatchTopic(filter string, topic string) (elements []string, matched bool) {
|
||||
filterParts := strings.Split(filter, "/")
|
||||
topicParts := strings.Split(topic, "/")
|
||||
|
||||
elements = make([]string, 0)
|
||||
for i := 0; i < len(filterParts); i++ {
|
||||
if i >= len(topicParts) {
|
||||
matched = false
|
||||
return
|
||||
}
|
||||
|
||||
if filterParts[i] == "+" {
|
||||
elements = append(elements, topicParts[i])
|
||||
continue
|
||||
}
|
||||
|
||||
if filterParts[i] == "#" {
|
||||
matched = true
|
||||
elements = append(elements, strings.Join(topicParts[i:], "/"))
|
||||
return
|
||||
}
|
||||
|
||||
if filterParts[i] != topicParts[i] {
|
||||
matched = false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return elements, true
|
||||
}
|
||||
|
||||
// Ledger is an auth ledger containing access rules for users and topics.
|
||||
type Ledger struct {
|
||||
sync.Mutex `json:"-" yaml:"-"`
|
||||
Users Users `json:"users" yaml:"users"`
|
||||
Auth AuthRules `json:"auth" yaml:"auth"`
|
||||
ACL ACLRules `json:"acl" yaml:"acl"`
|
||||
}
|
||||
|
||||
// Update updates the internal values of the ledger.
|
||||
func (l *Ledger) Update(ln *Ledger) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Auth = ln.Auth
|
||||
l.ACL = ln.ACL
|
||||
}
|
||||
|
||||
// AuthOk returns true if the rules indicate the user is allowed to authenticate.
|
||||
func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) {
|
||||
// If the users map is set, always check for a predefined user first instead
|
||||
// of iterating through global rules.
|
||||
if l.Users != nil {
|
||||
if u, ok := l.Users[string(cl.Properties.Username)]; ok &&
|
||||
u.Password != "" &&
|
||||
u.Password == RString(pk.Connect.Password) {
|
||||
return 0, !u.Disallow
|
||||
}
|
||||
}
|
||||
|
||||
// If there's no users map, or no user was found, attempt to find a matching
|
||||
// rule (which may also contain a user).
|
||||
for n, rule := range l.Auth {
|
||||
if rule.Client.Matches(cl.ID) &&
|
||||
rule.Username.Matches(string(cl.Properties.Username)) &&
|
||||
rule.Password.Matches(string(pk.Connect.Password)) &&
|
||||
rule.Remote.Matches(cl.Net.Remote) {
|
||||
return n, rule.Allow
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// ACLOk returns true if the rules indicate the user is allowed to read or write to
|
||||
// a specific filter or topic respectively, based on the write bool.
|
||||
func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) {
|
||||
// If the users map is set, always check for a predefined user first instead
|
||||
// of iterating through global rules.
|
||||
if l.Users != nil {
|
||||
if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 {
|
||||
for filter, access := range u.ACL {
|
||||
if filter.FilterMatches(topic) {
|
||||
if !write && (access == ReadOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else if write && (access == WriteOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else {
|
||||
return n, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for n, rule := range l.ACL {
|
||||
if rule.Client.Matches(cl.ID) &&
|
||||
rule.Username.Matches(string(cl.Properties.Username)) &&
|
||||
rule.Remote.Matches(cl.Net.Remote) {
|
||||
if len(rule.Filters) == 0 {
|
||||
return n, true
|
||||
}
|
||||
|
||||
for filter, access := range rule.Filters {
|
||||
if filter.FilterMatches(topic) {
|
||||
if !write && (access == ReadOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else if write && (access == WriteOnly || access == ReadWrite) {
|
||||
return n, true
|
||||
} else {
|
||||
return n, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, true
|
||||
}
|
||||
|
||||
// ToJSON encodes the values into a JSON string.
|
||||
func (l *Ledger) ToJSON() (data []byte, err error) {
|
||||
return json.Marshal(l)
|
||||
}
|
||||
|
||||
// ToYAML encodes the values into a YAML string.
|
||||
func (l *Ledger) ToYAML() (data []byte, err error) {
|
||||
return yaml.Marshal(l)
|
||||
}
|
||||
|
||||
// Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct.
|
||||
func (l *Ledger) Unmarshal(data []byte) error {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if data[0] == '{' {
|
||||
return json.Unmarshal(data, l)
|
||||
}
|
||||
|
||||
return yaml.Unmarshal(data, &l)
|
||||
}
|
||||
610
backend/services/mochi/hooks/auth/ledger_test.go
Normal file
610
backend/services/mochi/hooks/auth/ledger_test.go
Normal file
|
|
@ -0,0 +1,610 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
checkLedger = Ledger{
|
||||
Users: Users{ // users are allowed by default
|
||||
"mochi-co": {
|
||||
Password: "melon",
|
||||
ACL: Filters{
|
||||
"d/+/f": Deny,
|
||||
"mochi-co/#": ReadWrite,
|
||||
"readonly": ReadOnly,
|
||||
},
|
||||
},
|
||||
"suspended-username": {
|
||||
Password: "any",
|
||||
Disallow: true,
|
||||
},
|
||||
"mochi": { // ACL only, will defer to AuthRules for authentication
|
||||
ACL: Filters{
|
||||
"special/mochi": ReadOnly,
|
||||
"secret/mochi": Deny,
|
||||
"ignored": ReadWrite,
|
||||
},
|
||||
},
|
||||
},
|
||||
Auth: AuthRules{
|
||||
{Username: "banned-user"}, // never allow specific username
|
||||
{Remote: "127.0.0.1", Allow: true}, // always allow localhost
|
||||
{Remote: "123.123.123.123"}, // disallow any from specific address
|
||||
{Username: "not-mochi", Remote: "111.144.155.166"}, // disallow specific username and address
|
||||
{Remote: "111.*", Allow: true}, // allow any in wildcard (that isn't the above username)
|
||||
{Username: "mochi", Password: "melon", Allow: true}, // allow matching user/pass
|
||||
{Username: "mochi-co", Password: "melon", Allow: false}, // allow matching user/pass (should never trigger due to Users map)
|
||||
},
|
||||
ACL: ACLRules{
|
||||
{
|
||||
Username: "mochi", // allow matching user/pass
|
||||
Filters: Filters{
|
||||
"a/b/c": Deny,
|
||||
"d/+/f": Deny,
|
||||
"mochi/#": ReadWrite,
|
||||
"updates/#": WriteOnly,
|
||||
"readonly": ReadOnly,
|
||||
"ignored": Deny,
|
||||
},
|
||||
},
|
||||
{Remote: "localhost", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to localhost
|
||||
{Username: "admin", Filters: Filters{"$SYS/#": ReadOnly}}, // allow $SYS access to admin
|
||||
{Remote: "001.002.003.004"}, // Allow all with no filter
|
||||
{Filters: Filters{"$SYS/#": Deny}}, // Deny $SYS access to all others
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func TestRStringMatches(t *testing.T) {
|
||||
require.True(t, RString("*").Matches("any"))
|
||||
require.True(t, RString("*").Matches(""))
|
||||
require.True(t, RString("").Matches("any"))
|
||||
require.True(t, RString("").Matches(""))
|
||||
require.False(t, RString("no").Matches("any"))
|
||||
require.False(t, RString("no").Matches(""))
|
||||
}
|
||||
|
||||
func TestCanAuthenticate(t *testing.T) {
|
||||
tt := []struct {
|
||||
desc string
|
||||
client *mqtt.Client
|
||||
pk packets.Packet
|
||||
n int
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
desc: "allow all local 127.0.0.1",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{}},
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 5,
|
||||
},
|
||||
{
|
||||
desc: "deny username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "allow all local 127.0.0.1",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 5,
|
||||
},
|
||||
{
|
||||
desc: "deny username/password",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "deny client from address",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("not-mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "111.144.155.166",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: false,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
desc: "allow remote wildcard",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "111.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: true,
|
||||
n: 4,
|
||||
},
|
||||
{
|
||||
desc: "never allow username",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("banned-user"),
|
||||
},
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "matching user in users",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi-co"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}},
|
||||
ok: true,
|
||||
n: 0,
|
||||
},
|
||||
{
|
||||
desc: "never user in users",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("suspended-user"),
|
||||
},
|
||||
},
|
||||
pk: packets.Packet{Connect: packets.ConnectParams{Password: []byte("any")}},
|
||||
ok: false,
|
||||
n: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range tt {
|
||||
t.Run(d.desc, func(t *testing.T) {
|
||||
n, ok := checkLedger.AuthOk(d.client, d.pk)
|
||||
require.Equal(t, d.n, n)
|
||||
require.Equal(t, d.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanACL(t *testing.T) {
|
||||
tt := []struct {
|
||||
client *mqtt.Client
|
||||
desc string
|
||||
topic string
|
||||
n int
|
||||
write bool
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
desc: "allow normal write on any other filter",
|
||||
client: &mqtt.Client{},
|
||||
topic: "default/acl/write/access",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow normal read on any other filter",
|
||||
client: &mqtt.Client{},
|
||||
topic: "default/acl/read/access",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny user on literal filter",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "a/b/c",
|
||||
},
|
||||
{
|
||||
desc: "deny user on partial filter",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "d/j/f",
|
||||
},
|
||||
{
|
||||
desc: "allow read/write to user path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "mochi/read/write",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny read on write-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/no/reading",
|
||||
write: false,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "deny read on write-only path ext",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/mochi",
|
||||
write: false,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "allow read on not-acl path (no #)",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow write on write-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "updates/mochi",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "deny write on read-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "readonly",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "allow read on read-only path",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "readonly",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "allow $sys access to localhost",
|
||||
client: &mqtt.Client{
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "localhost",
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: true,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
desc: "allow $sys access to admin",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("admin"),
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: true,
|
||||
n: 2,
|
||||
},
|
||||
{
|
||||
desc: "deny $sys access to all others",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "$SYS/test",
|
||||
write: false,
|
||||
ok: false,
|
||||
n: 4,
|
||||
},
|
||||
{
|
||||
desc: "allow all with no filter",
|
||||
client: &mqtt.Client{
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "001.002.003.004",
|
||||
},
|
||||
},
|
||||
topic: "any/path",
|
||||
write: true,
|
||||
ok: true,
|
||||
n: 3,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl deny",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "secret/mochi",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl any",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "any/mochi",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl write on read-only",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "special/mochi",
|
||||
write: true,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
desc: "use users embedded acl read on read-only",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "special/mochi",
|
||||
write: false,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "preference users embedded acl",
|
||||
client: &mqtt.Client{
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("mochi"),
|
||||
},
|
||||
},
|
||||
topic: "ignored",
|
||||
write: true,
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, d := range tt {
|
||||
t.Run(d.desc, func(t *testing.T) {
|
||||
n, ok := checkLedger.ACLOk(d.client, d.topic, d.write)
|
||||
require.Equal(t, d.n, n)
|
||||
require.Equal(t, d.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchTopic(t *testing.T) {
|
||||
el, matched := MatchTopic("a/+/c/+", "a/b/c/d")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "d"}, el)
|
||||
|
||||
el, matched = MatchTopic("a/+/+/+", "a/b/c/d")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "c", "d"}, el)
|
||||
|
||||
el, matched = MatchTopic("stuff/#", "stuff/things/yeah")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"things/yeah"}, el)
|
||||
|
||||
el, matched = MatchTopic("a/+/#/+", "a/b/c/d/as/dds")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, []string{"b", "c/d/as/dds"}, el)
|
||||
|
||||
el, matched = MatchTopic("test", "test")
|
||||
require.True(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic("things/stuff//", "things/stuff/")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic("t", "t2")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
|
||||
el, matched = MatchTopic(" ", " ")
|
||||
require.False(t, matched)
|
||||
require.Equal(t, make([]string, 0), el)
|
||||
}
|
||||
|
||||
var (
|
||||
ledgerStruct = Ledger{
|
||||
Users: Users{
|
||||
"mochi": {
|
||||
Password: "peach",
|
||||
ACL: Filters{
|
||||
"readonly": ReadOnly,
|
||||
"deny": Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
Auth: AuthRules{
|
||||
{
|
||||
Client: "*",
|
||||
Username: "mochi-co",
|
||||
Password: "melon",
|
||||
Remote: "192.168.1.*",
|
||||
Allow: true,
|
||||
},
|
||||
},
|
||||
ACL: ACLRules{
|
||||
{
|
||||
Client: "*",
|
||||
Username: "mochi-co",
|
||||
Remote: "127.*",
|
||||
Filters: Filters{
|
||||
"readonly": ReadOnly,
|
||||
"writeonly": WriteOnly,
|
||||
"readwrite": ReadWrite,
|
||||
"deny": Deny,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ledgerJSON = []byte(`{"users":{"mochi":{"password":"peach","acl":{"deny":0,"readonly":1}}},"auth":[{"client":"*","username":"mochi-co","remote":"192.168.1.*","password":"melon","allow":true}],"acl":[{"client":"*","username":"mochi-co","remote":"127.*","filters":{"deny":0,"readonly":1,"readwrite":3,"writeonly":2}}]}`)
|
||||
ledgerYAML = []byte(`users:
|
||||
mochi:
|
||||
password: peach
|
||||
acl:
|
||||
deny: 0
|
||||
readonly: 1
|
||||
auth:
|
||||
- client: '*'
|
||||
username: mochi-co
|
||||
remote: 192.168.1.*
|
||||
password: melon
|
||||
allow: true
|
||||
acl:
|
||||
- client: '*'
|
||||
username: mochi-co
|
||||
remote: 127.*
|
||||
filters:
|
||||
deny: 0
|
||||
readonly: 1
|
||||
readwrite: 3
|
||||
writeonly: 2
|
||||
`)
|
||||
)
|
||||
|
||||
func TestLedgerUpdate(t *testing.T) {
|
||||
old := &Ledger{
|
||||
Auth: AuthRules{
|
||||
{Remote: "127.0.0.1", Allow: true},
|
||||
},
|
||||
}
|
||||
|
||||
new := &Ledger{
|
||||
Auth: AuthRules{
|
||||
{Remote: "127.0.0.1", Allow: true},
|
||||
{Remote: "192.168.*", Allow: true},
|
||||
},
|
||||
}
|
||||
|
||||
old.Update(new)
|
||||
require.Len(t, old.Auth, 2)
|
||||
require.Equal(t, RString("192.168.*"), old.Auth[1].Remote)
|
||||
require.NotSame(t, new, old)
|
||||
}
|
||||
|
||||
func TestLedgerToJSON(t *testing.T) {
|
||||
data, err := ledgerStruct.ToJSON()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerJSON, data)
|
||||
}
|
||||
|
||||
func TestLedgerToYAML(t *testing.T) {
|
||||
data, err := ledgerStruct.ToYAML()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ledgerYAML, data)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalFromYAML(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal(ledgerYAML)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &ledgerStruct, l)
|
||||
require.NotSame(t, l, &ledgerStruct)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalFromJSON(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal(ledgerJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &ledgerStruct, l)
|
||||
require.NotSame(t, l, &ledgerStruct)
|
||||
}
|
||||
|
||||
func TestLedgerUnmarshalNil(t *testing.T) {
|
||||
l := new(Ledger)
|
||||
err := l.Unmarshal([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, new(Ledger), l)
|
||||
}
|
||||
250
backend/services/mochi/hooks/debug/debug.go
Normal file
250
backend/services/mochi/hooks/debug/debug.go
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Options contains configuration settings for the debug output.
|
||||
type Options struct {
|
||||
ShowPacketData bool // include decoded packet data (default false)
|
||||
ShowPings bool // show ping requests and responses (default false)
|
||||
ShowPasswords bool // show connecting user passwords (default false)
|
||||
}
|
||||
|
||||
// Hook is a debugging hook which logs additional low-level information from the server.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options
|
||||
Log *zerolog.Logger
|
||||
}
|
||||
|
||||
// ID returns the ID of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "debug"
|
||||
}
|
||||
|
||||
// Provides indicates that this hook provides all methods.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Init is called when the hook is initialized.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetOpts is called when the hook receives inheritable server parameters.
|
||||
func (h *Hook) SetOpts(l *zerolog.Logger, opts *mqtt.HookOptions) {
|
||||
h.Log = l
|
||||
h.Log.Debug().Interface("opts", opts).Str("method", "SetOpts").Send()
|
||||
}
|
||||
|
||||
// Stop is called when the hook is stopped.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Debug().Str("method", "Stop").Send()
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStarted is called when the server starts.
|
||||
func (h *Hook) OnStarted() {
|
||||
h.Log.Debug().Str("method", "OnStarted").Send()
|
||||
}
|
||||
|
||||
// OnStopped is called when the server stops.
|
||||
func (h *Hook) OnStopped() {
|
||||
h.Log.Debug().Str("method", "OnStopped").Send()
|
||||
}
|
||||
|
||||
// OnPacketRead is called when a new packet is received from a client.
|
||||
func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet is sent to a client.
|
||||
func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) {
|
||||
if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings {
|
||||
return
|
||||
}
|
||||
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID)
|
||||
}
|
||||
|
||||
// OnRetainMessage is called when a published message is retained (or retain deleted/modified).
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("retained message on topic")
|
||||
}
|
||||
|
||||
// OnQosPublish is called when a publish packet with Qos is issued to a subscriber.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight out")
|
||||
}
|
||||
|
||||
// OnQosComplete is called when the Qos flow for a message has been completed.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight complete")
|
||||
}
|
||||
|
||||
// OnQosDropped is called the Qos flow for a message expires.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Interface("m", h.packetMeta(pk)).Msgf("inflight dropped")
|
||||
}
|
||||
|
||||
// OnLWTSent is called when a will message has been issued from a disconnecting client.
|
||||
func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.Log.Debug().Str("method", "OnLWTSent").Str("client", cl.ID).Msg("sent lwt for client")
|
||||
}
|
||||
|
||||
// OnRetainedExpired is called when the server clears expired retained messages.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
h.Log.Debug().Str("method", "OnRetainedExpired").Str("topic", filter).Msg("retained message expired")
|
||||
}
|
||||
|
||||
// OnClientExpired is called when the server clears an expired client.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
h.Log.Debug().Str("method", "OnClientExpired").Str("client", cl.ID).Msg("client session expired")
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores clients from a store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredClients").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores subscriptions from a store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredSubscriptions").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores retained messages from a store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredRetainedMessages").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores inflight messages from a store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredInflightMessages").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredClients is called when the server restores system info from a store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
h.Log.Debug().
|
||||
Str("method", "StoredClients").
|
||||
Send()
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// packetMeta adds additional type-specific metadata to the debug logs.
|
||||
func (h *Hook) packetMeta(pk packets.Packet) map[string]any {
|
||||
m := map[string]any{}
|
||||
switch pk.FixedHeader.Type {
|
||||
case packets.Connect:
|
||||
m["id"] = pk.Connect.ClientIdentifier
|
||||
m["clean"] = pk.Connect.Clean
|
||||
m["keepalive"] = pk.Connect.Keepalive
|
||||
m["version"] = pk.ProtocolVersion
|
||||
m["username"] = string(pk.Connect.Username)
|
||||
if h.config.ShowPasswords {
|
||||
m["password"] = string(pk.Connect.Password)
|
||||
}
|
||||
if pk.Connect.WillFlag {
|
||||
m["will_topic"] = pk.Connect.WillTopic
|
||||
m["will_payload"] = string(pk.Connect.WillPayload)
|
||||
}
|
||||
case packets.Publish:
|
||||
m["topic"] = pk.TopicName
|
||||
m["payload"] = string(pk.Payload)
|
||||
m["raw"] = pk.Payload
|
||||
m["qos"] = pk.FixedHeader.Qos
|
||||
m["id"] = pk.PacketID
|
||||
case packets.Connack:
|
||||
fallthrough
|
||||
case packets.Disconnect:
|
||||
fallthrough
|
||||
case packets.Puback:
|
||||
fallthrough
|
||||
case packets.Pubrec:
|
||||
fallthrough
|
||||
case packets.Pubrel:
|
||||
fallthrough
|
||||
case packets.Pubcomp:
|
||||
m["id"] = pk.PacketID
|
||||
m["reason"] = int(pk.ReasonCode)
|
||||
if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 {
|
||||
m["reason_string"] = pk.Properties.ReasonString
|
||||
}
|
||||
case packets.Subscribe:
|
||||
f := map[string]int{}
|
||||
ids := map[string]int{}
|
||||
for _, v := range pk.Filters {
|
||||
f[v.Filter] = int(v.Qos)
|
||||
ids[v.Filter] = v.Identifier
|
||||
}
|
||||
m["filters"] = f
|
||||
m["subids"] = f
|
||||
|
||||
case packets.Unsubscribe:
|
||||
f := []string{}
|
||||
for _, v := range pk.Filters {
|
||||
f = append(f, v.Filter)
|
||||
}
|
||||
m["filters"] = f
|
||||
case packets.Suback:
|
||||
fallthrough
|
||||
case packets.Unsuback:
|
||||
r := []int{}
|
||||
for _, v := range pk.ReasonCodes {
|
||||
r = append(r, int(v))
|
||||
}
|
||||
m["reasons"] = r
|
||||
case packets.Auth:
|
||||
// tbd
|
||||
}
|
||||
|
||||
if h.config.ShowPacketData {
|
||||
m["packet"] = pk
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
473
backend/services/mochi/hooks/storage/badger/badger.go
Normal file
473
backend/services/mochi/hooks/storage/badger/badger.go
Normal file
|
|
@ -0,0 +1,473 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/timshannon/badgerhold"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultDbFile is the default file path for the badger db file.
|
||||
defaultDbFile = ".badger"
|
||||
)
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return storage.RetainedKey + "_" + topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the BadgerDB instance.
|
||||
type Options struct {
|
||||
Options *badgerhold.Options
|
||||
Path string
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using BadgerDB file store as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for configuring the BadgerDB instance.
|
||||
db *badgerhold.Store // the BadgerDB instance.
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "badger-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init initializes and connects to the badger instance.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
if h.config.Path == "" {
|
||||
h.config.Path = defaultDbFile
|
||||
}
|
||||
|
||||
options := badgerhold.DefaultOptions
|
||||
options.Dir = h.config.Path
|
||||
options.ValueDir = h.config.Path
|
||||
options.Logger = h
|
||||
|
||||
var err error
|
||||
h.db, err = badgerhold.Open(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the badger instance.
|
||||
func (h *Hook) Stop() error {
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: clientKey(cl),
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert client data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if their session has expired.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
h.updateClient(cl)
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", clientKey(cl)).Msg("failed to delete client data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.Delete(subscriptionKey(cl, pk.Filters[i].Filter), new(storage.Subscription))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", subscriptionKey(cl, pk.Filters[i].Filter)).Msg("failed to delete subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.Delete(retainedKey(pk.TopicName), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", retainedKey(pk.TopicName)).Msg("failed to delete retained message data")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert retained message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
PacketID: pk.PacketID,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert qos inflight data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(inflightKey(cl, pk), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", inflightKey(cl, pk)).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
}
|
||||
|
||||
err := h.db.Upsert(in.ID, in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to upsert $SYS data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(retainedKey(filter), new(storage.Message))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete expired retained message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.Delete(clientKey(cl), new(storage.Client))
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client data")
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.ClientKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.SubscriptionKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.RetainedKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find(&v, badgerhold.Where("T").Eq(storage.InflightKey))
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Get(storage.SysInfoKey, &v)
|
||||
if err != nil && !errors.Is(err, badgerhold.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Errorf satisfies the badger interface for an error logger.
|
||||
func (h *Hook) Errorf(m string, v ...interface{}) {
|
||||
h.Log.Error().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
}
|
||||
|
||||
// Warningf satisfies the badger interface for a warning logger.
|
||||
func (h *Hook) Warningf(m string, v ...interface{}) {
|
||||
h.Log.Warn().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
}
|
||||
|
||||
// Infof satisfies the badger interface for an info logger.
|
||||
func (h *Hook) Infof(m string, v ...interface{}) {
|
||||
h.Log.Info().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
}
|
||||
|
||||
// Debugf satisfies the badger interface for a debug logger.
|
||||
func (h *Hook) Debugf(m string, v ...interface{}) {
|
||||
h.Log.Debug().Interface("v", v).Msgf(strings.ToLower(strings.Trim(m, "\n")), v...)
|
||||
}
|
||||
681
backend/services/mochi/hooks/storage/badger/badger_test.go
Normal file
681
backend/services/mochi/hooks/storage/badger/badger_test.go
Normal file
|
|
@ -0,0 +1,681 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package badger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/timshannon/badgerhold"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
h.Stop()
|
||||
h.db.Badger().Close()
|
||||
err := os.RemoveAll("./" + strings.Replace(path, "..", "", -1))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, "cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, storage.InflightKey+"_cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "badger-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
require.Equal(t, defaultDbFile, h.config.Path)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.Get(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
err = h.db.Get(clientKey(client), r2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
err = h.db.Get(clientKey(client), r3)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err = h.db.Upsert(clientKey, &storage.Client{ID: cl.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.Get(clientKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ID, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
err = h.db.Get(clientKey, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, badgerhold.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.Get(clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
r := new(storage.Subscription)
|
||||
|
||||
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
err = h.db.Get(subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, badgerhold.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.Get(retainedKey(pk.TopicName), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.db.Get(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.db.Get(retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err = h.db.Upsert(m.ID, m)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.Get(m.ID, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
err = h.db.Get(m.ID, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.Get(inflightKey(client, pk), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
err = h.db.Get(inflightKey(client, pk), r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, badgerhold.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
err = h.db.Get(storage.SysInfoKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with clients
|
||||
err = h.db.Upsert("cl1", &storage.Client{ID: "cl1", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("cl2", &storage.Client{ID: "cl2", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("cl3", &storage.Client{ID: "cl3", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err = h.db.Upsert("sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("m2", &storage.Message{ID: "m2", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("m3", &storage.Message{ID: "m3", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Upsert("i1", &storage.Message{ID: "i1", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("i2", &storage.Message{ID: "i2", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("i3", &storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Upsert("m1", &storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Upsert(storage.SysInfoKey, &storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestErrorf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.Errorf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestWarningf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.Warningf("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestInfof(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.Infof("test", 1, 2, 3)
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
// coverage: one day check log hook
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.Debugf("test", 1, 2, 3)
|
||||
}
|
||||
474
backend/services/mochi/hooks/storage/bolt/bolt.go
Normal file
474
backend/services/mochi/hooks/storage/bolt/bolt.go
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
// package bolt is provided for historical compatibility and may not be actively updated, you should use the badger hook instead.
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
sgob "github.com/asdine/storm/codec/gob"
|
||||
"github.com/asdine/storm/v3"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultDbFile is the default file path for the boltdb file.
|
||||
defaultDbFile = "bolt.db"
|
||||
|
||||
// defaultTimeout is the default time to hold a connection to the file.
|
||||
defaultTimeout = 250 * time.Millisecond
|
||||
)
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return storage.SubscriptionKey + "_" + cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return storage.RetainedKey + "_" + topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return storage.InflightKey + "_" + cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the bolt instance.
|
||||
type Options struct {
|
||||
Options *bbolt.Options
|
||||
Path string
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using boltdb file store as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for configuring the boltdb instance.
|
||||
db *storm.DB // the boltdb instance.
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "bolt-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// Init initializes and connects to the boltdb instance.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
config = new(Options)
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
if h.config.Options == nil {
|
||||
h.config.Options = &bbolt.Options{
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
if h.config.Path == "" {
|
||||
h.config.Path = defaultDbFile
|
||||
}
|
||||
|
||||
var err error
|
||||
h.db, err = storm.Open(h.config.Path, storm.BoltOptions(0600, h.config.Options), storm.Codec(sgob.Codec))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop closes the boltdb instance.
|
||||
func (h *Hook) Stop() error {
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: clientKey(cl),
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to save client data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.DeleteStruct(&storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", subscriptionKey(cl, pk.Filters[i].Filter)).
|
||||
Msg("failed to delete client")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.DeleteStruct(&storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", retainedKey(pk.TopicName)).
|
||||
Msg("failed to delete retained publish")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save retained publish data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("client", cl.ID).
|
||||
Interface("data", in).
|
||||
Msg("failed to save qos inflight data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
})
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Str("id", inflightKey(cl, pk)).
|
||||
Msg("failed to delete inflight data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
}
|
||||
|
||||
err := h.db.Save(in)
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).
|
||||
Interface("data", in).
|
||||
Msg("failed to save $SYS data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteStruct(&storage.Message{ID: retainedKey(filter)}); err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained publish")
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.DeleteStruct(&storage.Client{ID: clientKey(cl)})
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find("T", storage.ClientKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find("T", storage.SubscriptionKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find("T", storage.RetainedKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.Find("T", storage.InflightKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.db.One("ID", storage.SysInfoKey, &v)
|
||||
if err != nil && !errors.Is(err, storm.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
717
backend/services/mochi/hooks/storage/bolt/bolt_test.go
Normal file
717
backend/services/mochi/hooks/storage/bolt/bolt_test.go
Normal file
|
|
@ -0,0 +1,717 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package bolt
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/asdine/storm/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func teardown(t *testing.T, path string, h *Hook) {
|
||||
h.Stop()
|
||||
err := os.Remove(path)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, "cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, storage.SubscriptionKey+"_cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, storage.RetainedKey+"_a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, storage.InflightKey+"_cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.Equal(t, "bolt-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
require.Equal(t, defaultTimeout, h.config.Options.Timeout)
|
||||
require.Equal(t, defaultDbFile, h.config.Path)
|
||||
}
|
||||
|
||||
func TestInitBadPath(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Path: "..",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey(client), r2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey(client), r3)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, storm.ErrNotFound, err)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey(client), r)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err = h.db.Save(&storage.Client{ID: cl.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
err = h.db.One("ID", clientKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cl.ID, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
err = h.db.One("ID", clientKey, r)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
r := new(storage.Subscription)
|
||||
|
||||
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
err = h.db.One("ID", subscriptionKey(client, pkf.Filters[0].Filter), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.One("ID", retainedKey(pk.TopicName), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.db.One("ID", retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
err = h.db.One("ID", retainedKey(pk.TopicName), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err = h.db.Save(m)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.One("ID", m.ID, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
err = h.db.One("ID", m.ID, r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
err = h.db.One("ID", inflightKey(client, pk), r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved to bolt
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
err = h.db.One("ID", inflightKey(client, pk), r)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, storm.ErrNotFound, err)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
err = h.db.One("ID", storage.SysInfoKey, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with clients
|
||||
err = h.db.Save(&storage.Client{ID: "cl1", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Client{ID: "cl2", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Client{ID: "cl3", T: storage.ClientKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredClientsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err = h.db.Save(&storage.Subscription{ID: "sub1", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Subscription{ID: "sub2", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Subscription{ID: "sub3", T: storage.SubscriptionKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "m2", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "m3", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with messages
|
||||
err = h.db.Save(&storage.Message{ID: "i1", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i2", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "i3", T: storage.InflightKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.Save(&storage.Message{ID: "m1", T: storage.RetainedKey})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h.config.Path, h)
|
||||
|
||||
// populate with sys info
|
||||
err = h.db.Save(&storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoClosedDB(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
teardown(t, h.config.Path, h)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
513
backend/services/mochi/hooks/storage/redis/redis.go
Normal file
513
backend/services/mochi/hooks/storage/redis/redis.go
Normal file
|
|
@ -0,0 +1,513 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
// defaultAddr is the default address to the redis service.
|
||||
const defaultAddr = "localhost:6379"
|
||||
|
||||
// defaultHPrefix is a prefix to better identify hsets created by mochi mqtt.
|
||||
const defaultHPrefix = "mochi-"
|
||||
|
||||
// clientKey returns a primary key for a client.
|
||||
func clientKey(cl *mqtt.Client) string {
|
||||
return cl.ID
|
||||
}
|
||||
|
||||
// subscriptionKey returns a primary key for a subscription.
|
||||
func subscriptionKey(cl *mqtt.Client, filter string) string {
|
||||
return cl.ID + ":" + filter
|
||||
}
|
||||
|
||||
// retainedKey returns a primary key for a retained message.
|
||||
func retainedKey(topic string) string {
|
||||
return topic
|
||||
}
|
||||
|
||||
// inflightKey returns a primary key for an inflight message.
|
||||
func inflightKey(cl *mqtt.Client, pk packets.Packet) string {
|
||||
return cl.ID + ":" + pk.FormatID()
|
||||
}
|
||||
|
||||
// sysInfoKey returns a primary key for system info.
|
||||
func sysInfoKey() string {
|
||||
return storage.SysInfoKey
|
||||
}
|
||||
|
||||
// Options contains configuration settings for the bolt instance.
|
||||
type Options struct {
|
||||
HPrefix string
|
||||
Options *redis.Options
|
||||
}
|
||||
|
||||
// Hook is a persistent storage hook based using Redis as a backend.
|
||||
type Hook struct {
|
||||
mqtt.HookBase
|
||||
config *Options // options for connecting to the Redis instance.
|
||||
db *redis.Client // the Redis instance
|
||||
ctx context.Context // a context for the connection
|
||||
}
|
||||
|
||||
// ID returns the id of the hook.
|
||||
func (h *Hook) ID() string {
|
||||
return "redis-db"
|
||||
}
|
||||
|
||||
// Provides indicates which hook methods this hook provides.
|
||||
func (h *Hook) Provides(b byte) bool {
|
||||
return bytes.Contains([]byte{
|
||||
mqtt.OnSessionEstablished,
|
||||
mqtt.OnDisconnect,
|
||||
mqtt.OnSubscribed,
|
||||
mqtt.OnUnsubscribed,
|
||||
mqtt.OnRetainMessage,
|
||||
mqtt.OnQosPublish,
|
||||
mqtt.OnQosComplete,
|
||||
mqtt.OnQosDropped,
|
||||
mqtt.OnWillSent,
|
||||
mqtt.OnSysInfoTick,
|
||||
mqtt.OnClientExpired,
|
||||
mqtt.OnRetainedExpired,
|
||||
mqtt.StoredClients,
|
||||
mqtt.StoredInflightMessages,
|
||||
mqtt.StoredRetainedMessages,
|
||||
mqtt.StoredSubscriptions,
|
||||
mqtt.StoredSysInfo,
|
||||
}, []byte{b})
|
||||
}
|
||||
|
||||
// hKey returns a hash set key with a unique prefix.
|
||||
func (h *Hook) hKey(s string) string {
|
||||
return h.config.HPrefix + s
|
||||
}
|
||||
|
||||
// Init initializes and connects to the redis service.
|
||||
func (h *Hook) Init(config any) error {
|
||||
if _, ok := config.(*Options); !ok && config != nil {
|
||||
return mqtt.ErrInvalidConfigType
|
||||
}
|
||||
|
||||
h.ctx = context.Background()
|
||||
|
||||
if config == nil {
|
||||
config = &Options{
|
||||
Options: &redis.Options{
|
||||
Addr: defaultAddr,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
h.config = config.(*Options)
|
||||
if h.config.HPrefix == "" {
|
||||
h.config.HPrefix = defaultHPrefix
|
||||
}
|
||||
|
||||
h.Log.Info().
|
||||
Str("address", h.config.Options.Addr).
|
||||
Str("username", h.config.Options.Username).
|
||||
Int("password-len", len(h.config.Options.Password)).
|
||||
Int("db", h.config.Options.DB).
|
||||
Msg("connecting to redis service")
|
||||
|
||||
h.db = redis.NewClient(h.config.Options)
|
||||
_, err := h.db.Ping(context.Background()).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to ping service: %w", err)
|
||||
}
|
||||
|
||||
h.Log.Info().Msg("connected to redis service")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the redis connection.
|
||||
func (h *Hook) Stop() error {
|
||||
h.Log.Info().Msg("disconnecting from redis service")
|
||||
return h.db.Close()
|
||||
}
|
||||
|
||||
// OnSessionEstablished adds a client to the store when their session is established.
|
||||
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// OnWillSent is called when a client sends a will message and the will message is removed
|
||||
// from the client record.
|
||||
func (h *Hook) OnWillSent(cl *mqtt.Client, pk packets.Packet) {
|
||||
h.updateClient(cl)
|
||||
}
|
||||
|
||||
// updateClient writes the client data to the store.
|
||||
func (h *Hook) updateClient(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := cl.Properties.Props.Copy(false)
|
||||
in := &storage.Client{
|
||||
ID: clientKey(cl),
|
||||
T: storage.ClientKey,
|
||||
Remote: cl.Net.Remote,
|
||||
Listener: cl.Net.Listener,
|
||||
Username: cl.Properties.Username,
|
||||
Clean: cl.Properties.Clean,
|
||||
ProtocolVersion: cl.Properties.ProtocolVersion,
|
||||
Properties: storage.ClientProperties{
|
||||
SessionExpiryInterval: props.SessionExpiryInterval,
|
||||
AuthenticationMethod: props.AuthenticationMethod,
|
||||
AuthenticationData: props.AuthenticationData,
|
||||
RequestProblemInfo: props.RequestProblemInfo,
|
||||
RequestResponseInfo: props.RequestResponseInfo,
|
||||
ReceiveMaximum: props.ReceiveMaximum,
|
||||
TopicAliasMaximum: props.TopicAliasMaximum,
|
||||
User: props.User,
|
||||
MaximumPacketSize: props.MaximumPacketSize,
|
||||
},
|
||||
Will: storage.ClientWill(cl.Properties.Will),
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey(cl), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset client data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnDisconnect removes a client from the store if they were using a clean session.
|
||||
func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if !expire {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete client")
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubscribed adds one or more client subscriptions to the store.
|
||||
func (h *Hook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
var in *storage.Subscription
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
in = &storage.Subscription{
|
||||
ID: subscriptionKey(cl, pk.Filters[i].Filter),
|
||||
T: storage.SubscriptionKey,
|
||||
Client: cl.ID,
|
||||
Qos: reasonCodes[i],
|
||||
Filter: pk.Filters[i].Filter,
|
||||
Identifier: pk.Filters[i].Identifier,
|
||||
NoLocal: pk.Filters[i].NoLocal,
|
||||
RetainHandling: pk.Filters[i].RetainHandling,
|
||||
RetainAsPublished: pk.Filters[i].RetainAsPublished,
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnUnsubscribed removes one or more client subscriptions from the store.
|
||||
func (h *Hook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(pk.Filters); i++ {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(cl, pk.Filters[i].Filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete subscription data")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainMessage adds a retained message for a topic to the store.
|
||||
func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
if r == -1 {
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete retained message data")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: retainedKey(pk.TopicName),
|
||||
T: storage.RetainedKey,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Created: pk.Created,
|
||||
Origin: pk.Origin,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset retained message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosPublish adds or updates an inflight message in the store.
|
||||
func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
props := pk.Properties.Copy(false)
|
||||
in := &storage.Message{
|
||||
ID: inflightKey(cl, pk),
|
||||
T: storage.InflightKey,
|
||||
Origin: pk.Origin,
|
||||
FixedHeader: pk.FixedHeader,
|
||||
TopicName: pk.TopicName,
|
||||
Payload: pk.Payload,
|
||||
Sent: sent,
|
||||
Created: pk.Created,
|
||||
Properties: storage.MessageProperties{
|
||||
PayloadFormat: props.PayloadFormat,
|
||||
MessageExpiryInterval: props.MessageExpiryInterval,
|
||||
ContentType: props.ContentType,
|
||||
ResponseTopic: props.ResponseTopic,
|
||||
CorrelationData: props.CorrelationData,
|
||||
SubscriptionIdentifier: props.SubscriptionIdentifier,
|
||||
TopicAlias: props.TopicAlias,
|
||||
User: props.User,
|
||||
},
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset qos inflight message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosComplete removes a resolved inflight message from the store.
|
||||
func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.InflightKey), inflightKey(cl, pk)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete inflight message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnQosDropped removes a dropped inflight message from the store.
|
||||
func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
}
|
||||
|
||||
h.OnQosComplete(cl, pk)
|
||||
}
|
||||
|
||||
// OnSysInfoTick stores the latest system info in the store.
|
||||
func (h *Hook) OnSysInfoTick(sys *system.Info) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
in := &storage.SystemInfo{
|
||||
ID: sysInfoKey(),
|
||||
T: storage.SysInfoKey,
|
||||
Info: *sys,
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), sysInfoKey(), in).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Interface("data", in).Msg("failed to hset server info data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnRetainedExpired deletes expired retained messages from the store.
|
||||
func (h *Hook) OnRetainedExpired(filter string) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.RetainedKey), retainedKey(filter)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", retainedKey(filter)).Msg("failed to delete retained message data")
|
||||
}
|
||||
}
|
||||
|
||||
// OnClientExpired deleted expired clients from the store.
|
||||
func (h *Hook) OnClientExpired(cl *mqtt.Client) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.db.HDel(h.ctx, h.hKey(storage.ClientKey), clientKey(cl)).Err()
|
||||
if err != nil {
|
||||
h.Log.Error().Err(err).Str("id", clientKey(cl)).Msg("failed to delete expired client")
|
||||
}
|
||||
}
|
||||
|
||||
// StoredClients returns all stored clients from the store.
|
||||
func (h *Hook) StoredClients() (v []storage.Client, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.ClientKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll client data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Client
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal client data")
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSubscriptions returns all stored subscriptions from the store.
|
||||
func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.SubscriptionKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll subscription data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Subscription
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal subscription data")
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredRetainedMessages returns all stored retained messages from the store.
|
||||
func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.RetainedKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll retained message data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal retained message data")
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredInflightMessages returns all stored inflight messages from the store.
|
||||
func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := h.db.HGetAll(h.ctx, h.hKey(storage.InflightKey)).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
h.Log.Error().Err(err).Msg("failed to HGetAll inflight message data")
|
||||
return
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var d storage.Message
|
||||
if err = d.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal inflight message data")
|
||||
}
|
||||
|
||||
v = append(v, d)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// StoredSysInfo returns the system info from the store.
|
||||
func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.db == nil {
|
||||
h.Log.Error().Err(storage.ErrDBFileNotOpen)
|
||||
return
|
||||
}
|
||||
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = v.UnmarshalBinary([]byte(row)); err != nil {
|
||||
h.Log.Error().Err(err).Str("data", row).Msg("failed to unmarshal sys info data")
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
789
backend/services/mochi/hooks/storage/redis/redis_test.go
Normal file
789
backend/services/mochi/hooks/storage/redis/redis_test.go
Normal file
|
|
@ -0,0 +1,789 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2"
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
miniredis "github.com/alicebob/miniredis/v2"
|
||||
redis "github.com/go-redis/redis/v8"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
client = &mqtt.Client{
|
||||
ID: "test",
|
||||
Net: mqtt.ClientConnection{
|
||||
Remote: "test.addr",
|
||||
Listener: "listener",
|
||||
},
|
||||
Properties: mqtt.ClientProperties{
|
||||
Username: []byte("username"),
|
||||
Clean: false,
|
||||
},
|
||||
}
|
||||
|
||||
pkf = packets.Packet{Filters: packets.Subscriptions{{Filter: "a/b/c"}}}
|
||||
)
|
||||
|
||||
func newHook(t *testing.T, addr string) *Hook {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
Addr: addr,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func teardown(t *testing.T, h *Hook) {
|
||||
if h.db != nil {
|
||||
err := h.db.FlushAll(h.ctx).Err()
|
||||
require.NoError(t, err)
|
||||
h.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientKey(t *testing.T) {
|
||||
k := clientKey(&mqtt.Client{ID: "cl1"})
|
||||
require.Equal(t, "cl1", k)
|
||||
}
|
||||
|
||||
func TestSubscriptionKey(t *testing.T) {
|
||||
k := subscriptionKey(&mqtt.Client{ID: "cl1"}, "a/b/c")
|
||||
require.Equal(t, "cl1:a/b/c", k)
|
||||
}
|
||||
|
||||
func TestRetainedKey(t *testing.T) {
|
||||
k := retainedKey("a/b/c")
|
||||
require.Equal(t, "a/b/c", k)
|
||||
}
|
||||
|
||||
func TestInflightKey(t *testing.T) {
|
||||
k := inflightKey(&mqtt.Client{ID: "cl1"}, packets.Packet{PacketID: 1})
|
||||
require.Equal(t, "cl1:1", k)
|
||||
}
|
||||
|
||||
func TestSysInfoKey(t *testing.T) {
|
||||
require.Equal(t, storage.SysInfoKey, sysInfoKey())
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
require.Equal(t, "redis-db", h.ID())
|
||||
}
|
||||
|
||||
func TestProvides(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
require.True(t, h.Provides(mqtt.OnSessionEstablished))
|
||||
require.True(t, h.Provides(mqtt.OnDisconnect))
|
||||
require.True(t, h.Provides(mqtt.OnSubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnUnsubscribed))
|
||||
require.True(t, h.Provides(mqtt.OnRetainMessage))
|
||||
require.True(t, h.Provides(mqtt.OnQosPublish))
|
||||
require.True(t, h.Provides(mqtt.OnQosComplete))
|
||||
require.True(t, h.Provides(mqtt.OnQosDropped))
|
||||
require.True(t, h.Provides(mqtt.OnSysInfoTick))
|
||||
require.True(t, h.Provides(mqtt.StoredClients))
|
||||
require.True(t, h.Provides(mqtt.StoredInflightMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredRetainedMessages))
|
||||
require.True(t, h.Provides(mqtt.StoredSubscriptions))
|
||||
require.True(t, h.Provides(mqtt.StoredSysInfo))
|
||||
require.False(t, h.Provides(mqtt.OnACLCheck))
|
||||
require.False(t, h.Provides(mqtt.OnConnectAuthenticate))
|
||||
}
|
||||
|
||||
func TestHKey(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.SetOpts(&logger, nil)
|
||||
require.Equal(t, defaultHPrefix+"test", h.hKey("test"))
|
||||
}
|
||||
|
||||
func TestInitUseDefaults(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
s.StartAddr(defaultAddr)
|
||||
defer s.Close()
|
||||
|
||||
h := newHook(t, defaultAddr)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(nil)
|
||||
require.NoError(t, err)
|
||||
defer teardown(t, h)
|
||||
|
||||
require.Equal(t, defaultHPrefix, h.config.HPrefix)
|
||||
require.Equal(t, defaultAddr, h.config.Options.Addr)
|
||||
}
|
||||
|
||||
func TestInitBadConfig(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
|
||||
err := h.Init(map[string]any{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestInitBadAddr(t *testing.T) {
|
||||
h := new(Hook)
|
||||
h.SetOpts(&logger, nil)
|
||||
err := h.Init(&Options{
|
||||
Options: &redis.Options{
|
||||
Addr: "abc:123",
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedThenOnDisconnect(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
require.Equal(t, client.Net.Remote, r.Remote)
|
||||
require.Equal(t, client.Net.Listener, r.Listener)
|
||||
require.Equal(t, client.Properties.Username, r.Username)
|
||||
require.Equal(t, client.Properties.Clean, r.Clean)
|
||||
require.NotSame(t, client, r)
|
||||
|
||||
h.OnDisconnect(client, nil, false)
|
||||
r2 := new(storage.Client)
|
||||
row, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r2.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.ID)
|
||||
|
||||
h.OnDisconnect(client, nil, true)
|
||||
r3 := new(storage.Client)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
require.Empty(t, r3.ID)
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
|
||||
h.db = nil
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSessionEstablishedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSessionEstablished(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnWillSent(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
c1 := client
|
||||
c1.Properties.Will.Flag = 1
|
||||
h.OnWillSent(c1, packets.Packet{})
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey(client)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint32(1), r.Will.Flag)
|
||||
require.NotSame(t, client, r)
|
||||
}
|
||||
|
||||
func TestOnClientExpired(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
cl := &mqtt.Client{ID: "cl1"}
|
||||
clientKey := clientKey(cl)
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), clientKey, &storage.Client{ID: cl.ID}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Client)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientKey, r.ID)
|
||||
|
||||
h.OnClientExpired(cl)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.ClientKey), clientKey).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, redis.Nil, err)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnClientExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnClientExpired(client)
|
||||
}
|
||||
|
||||
func TestOnDisconnectNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnDisconnectClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnDisconnect(client, nil, false)
|
||||
}
|
||||
|
||||
func TestOnSubscribedThenOnUnsubscribed(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
|
||||
r := new(storage.Subscription)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, client.ID, r.Client)
|
||||
require.Equal(t, pkf.Filters[0].Filter, r.Filter)
|
||||
require.Equal(t, byte(0), r.Qos)
|
||||
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.SubscriptionKey), subscriptionKey(client, pkf.Filters[0].Filter)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnSubscribedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnSubscribedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSubscribed(client, pkf, []byte{0})
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnUnsubscribedClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnUnsubscribed(client, pkf)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageThenUnset(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnRetainMessage(client, pk, 1)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
|
||||
// coverage: delete deleted
|
||||
h.OnRetainMessage(client, pk, -1)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), retainedKey(pk.TopicName)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpired(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
m := &storage.Message{
|
||||
ID: retainedKey("a/b/c"),
|
||||
T: storage.RetainedKey,
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), m.ID, m).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, m.TopicName, r.TopicName)
|
||||
|
||||
h.OnRetainedExpired(m.TopicName)
|
||||
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.RetainedKey), m.ID).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainedExpiredNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
}
|
||||
|
||||
func TestOnRetainMessageNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnRetainMessageClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnRetainMessage(client, packets.Packet{}, 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishThenQOSComplete(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Retain: true,
|
||||
Qos: 2,
|
||||
},
|
||||
Payload: []byte("hello"),
|
||||
TopicName: "a/b/c",
|
||||
}
|
||||
|
||||
h.OnQosPublish(client, pk, time.Now().Unix(), 0)
|
||||
|
||||
r := new(storage.Message)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.TopicName, r.TopicName)
|
||||
require.Equal(t, pk.Payload, r.Payload)
|
||||
|
||||
// ensure dates are properly saved to bolt
|
||||
require.True(t, r.Sent > 0)
|
||||
require.True(t, time.Now().Unix()-1 < r.Sent)
|
||||
|
||||
// OnQosDropped is a passthrough to OnQosComplete here
|
||||
h.OnQosDropped(client, pk)
|
||||
_, err = h.db.HGet(h.ctx, h.hKey(storage.InflightKey), inflightKey(client, pk)).Result()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, redis.Nil)
|
||||
}
|
||||
|
||||
func TestOnQosPublishNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosPublishClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnQosPublish(client, packets.Packet{}, time.Now().Unix(), 0)
|
||||
}
|
||||
|
||||
func TestOnQosCompleteNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosCompleteClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnQosComplete(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnQosDroppedNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnQosDropped(client, packets.Packet{})
|
||||
}
|
||||
|
||||
func TestOnSysInfoTick(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
info := &system.Info{
|
||||
Version: "2.0.0",
|
||||
BytesReceived: 100,
|
||||
}
|
||||
|
||||
h.OnSysInfoTick(info)
|
||||
|
||||
r := new(storage.SystemInfo)
|
||||
row, err := h.db.HGet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey).Result()
|
||||
require.NoError(t, err)
|
||||
err = r.UnmarshalBinary([]byte(row))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, info.Version, r.Version)
|
||||
require.Equal(t, info.BytesReceived, r.BytesReceived)
|
||||
require.NotSame(t, info, r)
|
||||
}
|
||||
|
||||
func TestOnSysInfoTickClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
func TestOnSysInfoTickNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
}
|
||||
|
||||
func TestStoredClients(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with clients
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl1", &storage.Client{ID: "cl1", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl2", &storage.Client{ID: "cl2", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.ClientKey), "cl3", &storage.Client{ID: "cl3", T: storage.ClientKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "cl1", r[0].ID)
|
||||
require.Equal(t, "cl2", r[1].ID)
|
||||
require.Equal(t, "cl3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredClientsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredClientsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredClients()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptions(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with subscriptions
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub1", &storage.Subscription{ID: "sub1", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub2", &storage.Subscription{ID: "sub2", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.SubscriptionKey), "sub3", &storage.Subscription{ID: "sub3", T: storage.SubscriptionKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "sub1", r[0].ID)
|
||||
require.Equal(t, "sub2", r[1].ID)
|
||||
require.Equal(t, "sub3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSubscriptionsClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessages(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with messages
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m1", &storage.Message{ID: "m1", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m2", &storage.Message{ID: "m2", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "m1", r[0].ID)
|
||||
require.Equal(t, "m2", r[1].ID)
|
||||
require.Equal(t, "m3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredRetainedMessagesClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessages(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with messages
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i1", &storage.Message{ID: "i1", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i2", &storage.Message{ID: "i2", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.InflightKey), "i3", &storage.Message{ID: "i3", T: storage.InflightKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.db.HSet(h.ctx, h.hKey(storage.RetainedKey), "m3", &storage.Message{ID: "m3", T: storage.RetainedKey}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r, 3)
|
||||
sort.Slice(r[:], func(i, j int) bool { return r[i].ID < r[j].ID })
|
||||
require.Equal(t, "i1", r[0].ID)
|
||||
require.Equal(t, "i2", r[1].ID)
|
||||
require.Equal(t, "i3", r[2].ID)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredInflightMessagesClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfo(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
defer teardown(t, h)
|
||||
|
||||
// populate with sys info
|
||||
err := h.db.HSet(h.ctx, h.hKey(storage.SysInfoKey), storage.SysInfoKey,
|
||||
&storage.SystemInfo{
|
||||
ID: storage.SysInfoKey,
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
T: storage.SysInfoKey,
|
||||
}).Err()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", r.Info.Version)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoNoDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
h.db = nil
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestStoredSysInfoClosedDB(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
defer s.Close()
|
||||
h := newHook(t, s.Addr())
|
||||
teardown(t, h)
|
||||
|
||||
v, err := h.StoredSysInfo()
|
||||
require.Empty(t, v)
|
||||
require.Error(t, err)
|
||||
}
|
||||
164
backend/services/mochi/hooks/storage/storage.go
Normal file
164
backend/services/mochi/hooks/storage/storage.go
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
)
|
||||
|
||||
const (
|
||||
SubscriptionKey = "SUB" // unique key to denote Subscriptions in a store
|
||||
SysInfoKey = "SYS" // unique key to denote server system information in a store
|
||||
RetainedKey = "RET" // unique key to denote retained messages in a store
|
||||
InflightKey = "IFM" // unique key to denote inflight messages in a store
|
||||
ClientKey = "CL" // unique key to denote clients in a store
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading.
|
||||
ErrDBFileNotOpen = errors.New("db file not open")
|
||||
)
|
||||
|
||||
// Client is a storable representation of an mqtt client.
|
||||
type Client struct {
|
||||
Will ClientWill `json:"will"` // will topic and payload data if applicable
|
||||
Properties ClientProperties `json:"properties"` // the connect properties for the client
|
||||
Username []byte `json:"username"` // the username of the client
|
||||
ID string `json:"id" storm:"id"` // the client id / storage key
|
||||
T string `json:"t"` // the data type (client)
|
||||
Remote string `json:"remote"` // the remote address of the client
|
||||
Listener string `json:"listener"` // the listener the client connected on
|
||||
ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client
|
||||
Clean bool `json:"clean"` // if the client requested a clean start/session
|
||||
}
|
||||
|
||||
// ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection.
|
||||
type ClientProperties struct {
|
||||
AuthenticationData []byte `json:"authenticationData"`
|
||||
User []packets.UserProperty `json:"user"`
|
||||
AuthenticationMethod string `json:"authenticationMethod"`
|
||||
SessionExpiryInterval uint32 `json:"sessionExpiryInterval"`
|
||||
MaximumPacketSize uint32 `json:"maximumPacketSize"`
|
||||
ReceiveMaximum uint16 `json:"receiveMaximum"`
|
||||
TopicAliasMaximum uint16 `json:"topicAliasMaximum"`
|
||||
SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag"`
|
||||
RequestProblemInfo byte `json:"requestProblemInfo"`
|
||||
RequestProblemInfoFlag bool `json:"requestProblemInfoFlag"`
|
||||
RequestResponseInfo byte `json:"requestResponseInfo"`
|
||||
}
|
||||
|
||||
// ClientWill contains a will message for a client, and limited mqtt v5 properties.
|
||||
type ClientWill struct {
|
||||
Payload []byte `json:"payload"`
|
||||
User []packets.UserProperty `json:"user"`
|
||||
TopicName string `json:"topicName"`
|
||||
Flag uint32 `json:"flag"`
|
||||
WillDelayInterval uint32 `json:"willDelayInterval"`
|
||||
Qos byte `json:"qos"`
|
||||
Retain bool `json:"retain"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Client) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Client) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// Message is a storable representation of an MQTT message (specifically publish).
|
||||
type Message struct {
|
||||
Properties MessageProperties `json:"properties"` // -
|
||||
Payload []byte `json:"payload"` // the message payload (if retained)
|
||||
T string `json:"t"` // the data type
|
||||
ID string `json:"id" storm:"id"` // the storage key
|
||||
Origin string `json:"origin"` // the id of the client who sent the message
|
||||
TopicName string `json:"topic_name"` // the topic the message was sent to (if retained)
|
||||
FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message
|
||||
Created int64 `json:"created"` // the time the message was created in unixtime
|
||||
Sent int64 `json:"sent"` // the last time the message was sent (for retries) in unixtime (if inflight)
|
||||
PacketID uint16 `json:"packet_id"` // the unique id of the packet (if inflight)
|
||||
}
|
||||
|
||||
// MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages.
|
||||
type MessageProperties struct {
|
||||
CorrelationData []byte `json:"correlationData"`
|
||||
SubscriptionIdentifier []int `json:"subscriptionIdentifier"`
|
||||
User []packets.UserProperty `json:"user"`
|
||||
ContentType string `json:"contentType"`
|
||||
ResponseTopic string `json:"responseTopic"`
|
||||
MessageExpiryInterval uint32 `json:"messageExpiry"`
|
||||
TopicAlias uint16 `json:"topicAlias"`
|
||||
PayloadFormat byte `json:"payloadFormat"`
|
||||
PayloadFormatFlag bool `json:"payloadFormatFlag"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Message) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Message) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// Subscription is a storable representation of an mqtt subscription.
|
||||
type Subscription struct {
|
||||
T string `json:"t"`
|
||||
ID string `json:"id" storm:"id"`
|
||||
Client string `json:"client"`
|
||||
Filter string `json:"filter"`
|
||||
Identifier int `json:"identifier"`
|
||||
RetainHandling byte `json:"retain_handling"`
|
||||
Qos byte `json:"qos"`
|
||||
RetainAsPublished bool `json:"retain_as_pub"`
|
||||
NoLocal bool `json:"no_local"`
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d Subscription) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *Subscription) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
|
||||
// SystemInfo is a storable representation of the system information values.
|
||||
type SystemInfo struct {
|
||||
system.Info // embed the system info struct
|
||||
T string `json:"t"` // the data type
|
||||
ID string `json:"id" storm:"id"` // the storage key
|
||||
}
|
||||
|
||||
// MarshalBinary encodes the values into a json string.
|
||||
func (d SystemInfo) MarshalBinary() (data []byte, err error) {
|
||||
return json.Marshal(d)
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes a json string into a struct.
|
||||
func (d *SystemInfo) UnmarshalBinary(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(data, d)
|
||||
}
|
||||
196
backend/services/mochi/hooks/storage/storage_test.go
Normal file
196
backend/services/mochi/hooks/storage/storage_test.go
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
clientStruct = Client{
|
||||
ID: "test",
|
||||
T: "client",
|
||||
Remote: "remote",
|
||||
Listener: "listener",
|
||||
Username: []byte("mochi"),
|
||||
Clean: true,
|
||||
Properties: ClientProperties{
|
||||
SessionExpiryInterval: 2,
|
||||
SessionExpiryIntervalFlag: true,
|
||||
AuthenticationMethod: "a",
|
||||
AuthenticationData: []byte("test"),
|
||||
RequestProblemInfo: 1,
|
||||
RequestProblemInfoFlag: true,
|
||||
RequestResponseInfo: 1,
|
||||
ReceiveMaximum: 128,
|
||||
TopicAliasMaximum: 256,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k", Val: "v"},
|
||||
},
|
||||
MaximumPacketSize: 120,
|
||||
},
|
||||
Will: ClientWill{
|
||||
Qos: 1,
|
||||
Payload: []byte("abc"),
|
||||
TopicName: "a/b/c",
|
||||
Flag: 1,
|
||||
Retain: true,
|
||||
WillDelayInterval: 2,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k2", Val: "v2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)
|
||||
|
||||
messageStruct = Message{
|
||||
T: "message",
|
||||
Payload: []byte("payload"),
|
||||
FixedHeader: packets.FixedHeader{
|
||||
Remaining: 2,
|
||||
Type: 3,
|
||||
Qos: 1,
|
||||
Dup: true,
|
||||
Retain: true,
|
||||
},
|
||||
ID: "id",
|
||||
Origin: "mochi",
|
||||
TopicName: "topic",
|
||||
Properties: MessageProperties{
|
||||
PayloadFormat: 1,
|
||||
PayloadFormatFlag: true,
|
||||
MessageExpiryInterval: 20,
|
||||
ContentType: "type",
|
||||
ResponseTopic: "a/b/r",
|
||||
CorrelationData: []byte("r"),
|
||||
SubscriptionIdentifier: []int{1},
|
||||
TopicAlias: 2,
|
||||
User: []packets.UserProperty{
|
||||
{Key: "k2", Val: "v2"},
|
||||
},
|
||||
},
|
||||
Created: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
|
||||
Sent: time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
|
||||
PacketID: 100,
|
||||
}
|
||||
messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)
|
||||
|
||||
subscriptionStruct = Subscription{
|
||||
T: "subscription",
|
||||
ID: "id",
|
||||
Client: "mochi",
|
||||
Filter: "a/b/c",
|
||||
Qos: 1,
|
||||
}
|
||||
subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","identifier":0,"retain_handling":0,"qos":1,"retain_as_pub":false,"no_local":false}`)
|
||||
|
||||
sysInfoStruct = SystemInfo{
|
||||
T: "info",
|
||||
ID: "id",
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
Started: 1,
|
||||
Uptime: 2,
|
||||
BytesReceived: 3,
|
||||
BytesSent: 4,
|
||||
ClientsConnected: 5,
|
||||
ClientsMaximum: 7,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
PacketsReceived: 12,
|
||||
PacketsSent: 13,
|
||||
Retained: 15,
|
||||
Inflight: 16,
|
||||
InflightDropped: 17,
|
||||
},
|
||||
}
|
||||
sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
|
||||
)
|
||||
|
||||
func TestClientMarshalBinary(t *testing.T) {
|
||||
data, err := clientStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientJSON, data)
|
||||
}
|
||||
|
||||
func TestClientUnmarshalBinary(t *testing.T) {
|
||||
d := clientStruct
|
||||
err := d.UnmarshalBinary(clientJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, clientStruct, d)
|
||||
}
|
||||
|
||||
func TestClientUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Client{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Client{}, d)
|
||||
}
|
||||
|
||||
func TestMessageMarshalBinary(t *testing.T) {
|
||||
data, err := messageStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, messageJSON, data)
|
||||
}
|
||||
|
||||
func TestMessageUnmarshalBinary(t *testing.T) {
|
||||
d := messageStruct
|
||||
err := d.UnmarshalBinary(messageJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, messageStruct, d)
|
||||
}
|
||||
|
||||
func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Message{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Message{}, d)
|
||||
}
|
||||
|
||||
func TestSubscriptionMarshalBinary(t *testing.T) {
|
||||
data, err := subscriptionStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, subscriptionJSON, data)
|
||||
}
|
||||
|
||||
func TestSubscriptionUnmarshalBinary(t *testing.T) {
|
||||
d := subscriptionStruct
|
||||
err := d.UnmarshalBinary(subscriptionJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, subscriptionStruct, d)
|
||||
}
|
||||
|
||||
func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := Subscription{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, Subscription{}, d)
|
||||
}
|
||||
|
||||
func TestSysInfoMarshalBinary(t *testing.T) {
|
||||
data, err := sysInfoStruct.MarshalBinary()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sysInfoJSON, data)
|
||||
}
|
||||
|
||||
func TestSysInfoUnmarshalBinary(t *testing.T) {
|
||||
d := sysInfoStruct
|
||||
err := d.UnmarshalBinary(sysInfoJSON)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, sysInfoStruct, d)
|
||||
}
|
||||
|
||||
func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
|
||||
d := SystemInfo{}
|
||||
err := d.UnmarshalBinary([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SystemInfo{}, d)
|
||||
}
|
||||
634
backend/services/mochi/hooks_test.go
Normal file
634
backend/services/mochi/hooks_test.go
Normal file
|
|
@ -0,0 +1,634 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/hooks/storage"
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type modifiedHookBase struct {
|
||||
HookBase
|
||||
err error
|
||||
fail bool
|
||||
failAt int
|
||||
}
|
||||
|
||||
var errTestHook = errors.New("error")
|
||||
|
||||
func (h *modifiedHookBase) ID() string {
|
||||
return "modified"
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Init(config any) error {
|
||||
if config != nil {
|
||||
return errTestHook
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Provides(b byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) Stop() error {
|
||||
if h.fail {
|
||||
return errTestHook
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnACLCheck(cl *Client, topic string, write bool) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnPublish(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnPacketRead(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnAuthPacket(cl *Client, pk packets.Packet) (packets.Packet, error) {
|
||||
if h.fail {
|
||||
if h.err != nil {
|
||||
return pk, h.err
|
||||
}
|
||||
|
||||
return pk, errTestHook
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) OnWill(cl *Client, will Will) (Will, error) {
|
||||
if h.fail {
|
||||
return will, errTestHook
|
||||
}
|
||||
|
||||
return will, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredClients() (v []storage.Client, err error) {
|
||||
if h.fail || h.failAt == 1 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Client{
|
||||
{ID: "cl1"},
|
||||
{ID: "cl2"},
|
||||
{ID: "cl3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredSubscriptions() (v []storage.Subscription, err error) {
|
||||
if h.fail || h.failAt == 2 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Subscription{
|
||||
{ID: "sub1"},
|
||||
{ID: "sub2"},
|
||||
{ID: "sub3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredRetainedMessages() (v []storage.Message, err error) {
|
||||
if h.fail || h.failAt == 3 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Message{
|
||||
{ID: "r1"},
|
||||
{ID: "r2"},
|
||||
{ID: "r3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredInflightMessages() (v []storage.Message, err error) {
|
||||
if h.fail || h.failAt == 4 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return []storage.Message{
|
||||
{ID: "i1"},
|
||||
{ID: "i2"},
|
||||
{ID: "i3"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *modifiedHookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
|
||||
if h.fail || h.failAt == 5 {
|
||||
return v, errTestHook
|
||||
}
|
||||
|
||||
return storage.SystemInfo{
|
||||
Info: system.Info{
|
||||
Version: "2.0.0",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type providesCheckHook struct {
|
||||
HookBase
|
||||
}
|
||||
|
||||
func (h *providesCheckHook) Provides(b byte) bool {
|
||||
return b == OnConnect
|
||||
}
|
||||
|
||||
func TestHooksProvides(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(providesCheckHook), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, h.Provides(OnConnect, OnDisconnect))
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHooksAddLenGetAll(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(2), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(2), h.Len())
|
||||
|
||||
all := h.GetAll()
|
||||
require.Equal(t, "base", all[0].ID())
|
||||
require.Equal(t, "modified", all[1].ID())
|
||||
}
|
||||
|
||||
func TestHooksAddInitFailure(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), map[string]any{})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, int64(0), atomic.LoadInt64(&h.qty))
|
||||
}
|
||||
|
||||
func TestHooksStop(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
err := h.Add(new(HookBase), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), atomic.LoadInt64(&h.qty))
|
||||
require.Equal(t, int64(1), h.Len())
|
||||
|
||||
h.Stop()
|
||||
}
|
||||
|
||||
// coverage: also cover some empty functions
|
||||
func TestHooksNonReturns(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
cl := new(Client)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
t.Run("step-"+strconv.Itoa(i), func(t *testing.T) {
|
||||
// on first iteration, check without hook methods
|
||||
h.OnStarted()
|
||||
h.OnStopped()
|
||||
h.OnSysInfoTick(new(system.Info))
|
||||
h.OnConnect(cl, packets.Packet{})
|
||||
h.OnSessionEstablished(cl, packets.Packet{})
|
||||
h.OnDisconnect(cl, nil, false)
|
||||
h.OnPacketSent(cl, packets.Packet{}, []byte{})
|
||||
h.OnPacketProcessed(cl, packets.Packet{}, nil)
|
||||
h.OnSubscribed(cl, packets.Packet{}, []byte{1})
|
||||
h.OnUnsubscribed(cl, packets.Packet{})
|
||||
h.OnPublished(cl, packets.Packet{})
|
||||
h.OnPublishDropped(cl, packets.Packet{})
|
||||
h.OnRetainMessage(cl, packets.Packet{}, 0)
|
||||
h.OnQosPublish(cl, packets.Packet{}, time.Now().Unix(), 0)
|
||||
h.OnQosComplete(cl, packets.Packet{})
|
||||
h.OnQosDropped(cl, packets.Packet{})
|
||||
h.OnWillSent(cl, packets.Packet{})
|
||||
h.OnClientExpired(cl)
|
||||
h.OnRetainedExpired("a/b/c")
|
||||
|
||||
// on second iteration, check added hook methods
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHooksOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
|
||||
ok := h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.False(t, ok)
|
||||
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestHooksOnACLCheck(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
|
||||
ok := h.OnACLCheck(new(Client), "a/b/c", true)
|
||||
require.False(t, ok)
|
||||
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ok = h.OnACLCheck(new(Client), "a/b/c", true)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestHooksOnSubscribe(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pki := packets.Packet{
|
||||
Filters: packets.Subscriptions{
|
||||
{Filter: "a/b/c", Qos: 1},
|
||||
},
|
||||
}
|
||||
pk := h.OnSubscribe(new(Client), pki)
|
||||
require.EqualValues(t, pk, pki)
|
||||
}
|
||||
|
||||
func TestHooksOnSelectSubscribers(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
subs := &Subscribers{
|
||||
Subscriptions: map[string]packets.Subscription{
|
||||
"cl1": {Filter: "a/b/c"},
|
||||
},
|
||||
}
|
||||
|
||||
subs2 := h.OnSelectSubscribers(subs, packets.Packet{})
|
||||
require.EqualValues(t, subs, subs2)
|
||||
}
|
||||
|
||||
func TestHooksOnUnsubscribe(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
err := h.Add(new(modifiedHookBase), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pki := packets.Packet{
|
||||
Filters: packets.Subscriptions{
|
||||
{Filter: "a/b/c", Qos: 1},
|
||||
},
|
||||
}
|
||||
|
||||
pk := h.OnUnsubscribe(new(Client), pki)
|
||||
require.EqualValues(t, pk, pki)
|
||||
}
|
||||
|
||||
func TestHooksOnPublish(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: failure
|
||||
hook.fail = true
|
||||
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: reject packet
|
||||
hook.err = packets.ErrRejectPacket
|
||||
pk, err = h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrRejectPacket)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnPacketRead(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: failure
|
||||
hook.fail = true
|
||||
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
// coverage: reject packet
|
||||
hook.err = packets.ErrRejectPacket
|
||||
pk, err = h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, packets.ErrRejectPacket)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnAuthPacket(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
|
||||
hook.fail = true
|
||||
pk, err = h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Error(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnPacketEncode(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
pk := h.OnPacketEncode(new(Client), packets.Packet{PacketID: 10})
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHooksOnLWT(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err := h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
lwt := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
|
||||
// coverage: fail lwt
|
||||
hook.fail = true
|
||||
lwt = h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
}
|
||||
|
||||
func TestHooksStoredClients(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredClients()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredSubscriptions(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredSubscriptions()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredRetainedMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredRetainedMessages()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredInflightMessages(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 0)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v, 3)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredInflightMessages()
|
||||
require.Error(t, err)
|
||||
require.Len(t, v, 0)
|
||||
}
|
||||
|
||||
func TestHooksStoredSysInfo(t *testing.T) {
|
||||
h := new(Hooks)
|
||||
h.Log = &logger
|
||||
|
||||
v, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", v.Info.Version)
|
||||
|
||||
hook := new(modifiedHookBase)
|
||||
err = h.Add(hook, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
v, err = h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0.0", v.Info.Version)
|
||||
|
||||
hook.fail = true
|
||||
v, err = h.StoredSysInfo()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "", v.Info.Version)
|
||||
}
|
||||
|
||||
func TestHookBaseID(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Equal(t, "base", h.ID())
|
||||
}
|
||||
|
||||
func TestHookBaseProvidesNone(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.False(t, h.Provides(OnConnect))
|
||||
require.False(t, h.Provides(OnDisconnect))
|
||||
}
|
||||
|
||||
func TestHookBaseInit(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Nil(t, h.Init(nil))
|
||||
}
|
||||
|
||||
func TestHookBaseSetOpts(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
h.SetOpts(&logger, new(HookOptions))
|
||||
require.NotNil(t, h.Log)
|
||||
require.NotNil(t, h.Opts)
|
||||
}
|
||||
|
||||
func TestHookBaseClose(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
require.Nil(t, h.Stop())
|
||||
}
|
||||
|
||||
func TestHookBaseOnConnectAuthenticate(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v := h.OnConnectAuthenticate(new(Client), packets.Packet{})
|
||||
require.False(t, v)
|
||||
}
|
||||
func TestHookBaseOnACLCheck(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v := h.OnACLCheck(new(Client), "topic", true)
|
||||
require.False(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseOnPublish(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnPublish(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnPacketRead(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnPacketRead(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnAuthPacket(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
pk, err := h.OnAuthPacket(new(Client), packets.Packet{PacketID: 10})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint16(10), pk.PacketID)
|
||||
}
|
||||
|
||||
func TestHookBaseOnLWT(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
lwt, err := h.OnWill(new(Client), Will{TopicName: "a/b/c"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "a/b/c", lwt.TopicName)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredClients(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredClients()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredSubscriptions(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredSubscriptions()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredInflightMessages(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredInflightMessages()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoredRetainedMessages(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredRetainedMessages()
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, v)
|
||||
}
|
||||
|
||||
func TestHookBaseStoreSysInfo(t *testing.T) {
|
||||
h := new(HookBase)
|
||||
v, err := h.StoredSysInfo()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", v.Version)
|
||||
}
|
||||
156
backend/services/mochi/inflight.go
Normal file
156
backend/services/mochi/inflight.go
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
// Inflight is a map of InflightMessage keyed on packet id.
|
||||
type Inflight struct {
|
||||
sync.RWMutex
|
||||
internal map[uint16]packets.Packet // internal contains the inflight packets
|
||||
receiveQuota int32 // remaining inbound qos quota for flow control
|
||||
sendQuota int32 // remaining outbound qos quota for flow control
|
||||
maximumReceiveQuota int32 // maximum allowed receive quota
|
||||
maximumSendQuota int32 // maximum allowed send quota
|
||||
}
|
||||
|
||||
// NewInflights returns a new instance of an Inflight packets map.
|
||||
func NewInflights() *Inflight {
|
||||
return &Inflight{
|
||||
internal: map[uint16]packets.Packet{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds or updates an inflight packet by packet id.
|
||||
func (i *Inflight) Set(m packets.Packet) bool {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
_, ok := i.internal[m.PacketID]
|
||||
i.internal[m.PacketID] = m
|
||||
return !ok
|
||||
}
|
||||
|
||||
// Get returns an inflight packet by packet id.
|
||||
func (i *Inflight) Get(id uint16) (packets.Packet, bool) {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
if m, ok := i.internal[id]; ok {
|
||||
return m, true
|
||||
}
|
||||
|
||||
return packets.Packet{}, false
|
||||
}
|
||||
|
||||
// Len returns the size of the inflight messages map.
|
||||
func (i *Inflight) Len() int {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
return len(i.internal)
|
||||
}
|
||||
|
||||
// Clone returns a new instance of Inflight with the same message data.
|
||||
// This is used when transferring inflights from a taken-over session.
|
||||
func (i *Inflight) Clone() *Inflight {
|
||||
c := NewInflights()
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
for k, v := range i.internal {
|
||||
c.internal[k] = v
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// GetAll returns all the inflight messages.
|
||||
func (i *Inflight) GetAll(immediate bool) []packets.Packet {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
m := []packets.Packet{}
|
||||
for _, v := range i.internal {
|
||||
if !immediate || (immediate && v.Expiry < 0) {
|
||||
m = append(m, v)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(m, func(i, j int) bool {
|
||||
return uint16(m[i].Created) < uint16(m[j].Created)
|
||||
})
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// NextImmediate returns the next inflight packet which is indicated to be sent immediately.
|
||||
// This typically occurs when the quota has been exhausted, and we need to wait until new quota
|
||||
// is free to continue sending.
|
||||
func (i *Inflight) NextImmediate() (packets.Packet, bool) {
|
||||
i.RLock()
|
||||
defer i.RUnlock()
|
||||
|
||||
m := i.GetAll(true)
|
||||
if len(m) > 0 {
|
||||
return m[0], true
|
||||
}
|
||||
|
||||
return packets.Packet{}, false
|
||||
}
|
||||
|
||||
// Delete removes an in-flight message from the map. Returns true if the message existed.
|
||||
func (i *Inflight) Delete(id uint16) bool {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
_, ok := i.internal[id]
|
||||
delete(i.internal, id)
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// TakeRecieveQuota reduces the receive quota by 1.
|
||||
func (i *Inflight) DecreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) > 0 {
|
||||
atomic.AddInt32(&i.receiveQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// TakeRecieveQuota increases the receive quota by 1.
|
||||
func (i *Inflight) IncreaseReceiveQuota() {
|
||||
if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) {
|
||||
atomic.AddInt32(&i.receiveQuota, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetReceiveQuota resets the receive quota to the maximum allowed value.
|
||||
func (i *Inflight) ResetReceiveQuota(n int32) {
|
||||
atomic.StoreInt32(&i.receiveQuota, n)
|
||||
atomic.StoreInt32(&i.maximumReceiveQuota, n)
|
||||
}
|
||||
|
||||
// DecreaseSendQuota reduces the send quota by 1.
|
||||
func (i *Inflight) DecreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) > 0 {
|
||||
atomic.AddInt32(&i.sendQuota, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// IncreaseSendQuota increases the send quota by 1.
|
||||
func (i *Inflight) IncreaseSendQuota() {
|
||||
if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) {
|
||||
atomic.AddInt32(&i.sendQuota, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetSendQuota resets the send quota to the maximum allowed value.
|
||||
func (i *Inflight) ResetSendQuota(n int32) {
|
||||
atomic.StoreInt32(&i.sendQuota, n)
|
||||
atomic.StoreInt32(&i.maximumSendQuota, n)
|
||||
}
|
||||
199
backend/services/mochi/inflight_test.go
Normal file
199
backend/services/mochi/inflight_test.go
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInflightSet(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
r := cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.True(t, r)
|
||||
require.NotNil(t, cl.State.Inflight.internal[1])
|
||||
require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID)
|
||||
|
||||
r = cl.State.Inflight.Set(packets.Packet{PacketID: 1})
|
||||
require.False(t, r)
|
||||
}
|
||||
|
||||
func TestInflightGet(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
|
||||
msg, ok := cl.State.Inflight.Get(2)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, 0, msg.PacketID)
|
||||
}
|
||||
|
||||
func TestInflightGetAllAndImmediate(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
|
||||
|
||||
require.Equal(t, []packets.Packet{
|
||||
{PacketID: 1, Created: 1},
|
||||
{PacketID: 2, Created: 2},
|
||||
{PacketID: 3, Created: 3, Expiry: -1},
|
||||
{PacketID: 4, Created: 4, Expiry: -1},
|
||||
{PacketID: 5, Created: 5},
|
||||
}, cl.State.Inflight.GetAll(false))
|
||||
|
||||
require.Equal(t, []packets.Packet{
|
||||
{PacketID: 3, Created: 3, Expiry: -1},
|
||||
{PacketID: 4, Created: 4, Expiry: -1},
|
||||
}, cl.State.Inflight.GetAll(true))
|
||||
}
|
||||
|
||||
func TestInflightLen(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
}
|
||||
|
||||
func TestInflightClone(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2})
|
||||
require.Equal(t, 1, cl.State.Inflight.Len())
|
||||
|
||||
cloned := cl.State.Inflight.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
require.NotSame(t, cloned, cl.State.Inflight)
|
||||
}
|
||||
|
||||
func TestInflightDelete(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3})
|
||||
require.NotNil(t, cl.State.Inflight.internal[3])
|
||||
|
||||
r := cl.State.Inflight.Delete(3)
|
||||
require.True(t, r)
|
||||
require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID)
|
||||
|
||||
_, ok := cl.State.Inflight.Get(3)
|
||||
require.False(t, ok)
|
||||
|
||||
r = cl.State.Inflight.Delete(3)
|
||||
require.False(t, r)
|
||||
}
|
||||
|
||||
func TestResetReceiveQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
i.ResetReceiveQuota(6)
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota))
|
||||
}
|
||||
|
||||
func TestReceiveQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
i.receiveQuota = 4
|
||||
i.maximumReceiveQuota = 5
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Return 1
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.IncreaseReceiveQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Reset to max 1
|
||||
i.ResetReceiveQuota(1)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Take 1
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.DecreaseReceiveQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota))
|
||||
}
|
||||
|
||||
func TestResetSendQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
i.ResetSendQuota(6)
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota))
|
||||
}
|
||||
|
||||
func TestSendQuota(t *testing.T) {
|
||||
i := NewInflights()
|
||||
i.sendQuota = 4
|
||||
i.maximumSendQuota = 5
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Return 1
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go over max limit
|
||||
i.IncreaseSendQuota()
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Reset to max 1
|
||||
i.ResetSendQuota(1)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Take 1
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
|
||||
// Try to go below zero
|
||||
i.DecreaseSendQuota()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota))
|
||||
}
|
||||
|
||||
func TestNextImmediate(t *testing.T) {
|
||||
cl, _, _ := newTestClient()
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1})
|
||||
cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5})
|
||||
|
||||
pk, ok := cl.State.Inflight.NextImmediate()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk)
|
||||
|
||||
r := cl.State.Inflight.Delete(3)
|
||||
require.True(t, r)
|
||||
|
||||
pk, ok = cl.State.Inflight.NextImmediate()
|
||||
require.True(t, ok)
|
||||
require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk)
|
||||
|
||||
r = cl.State.Inflight.Delete(4)
|
||||
require.True(t, r)
|
||||
|
||||
_, ok = cl.State.Inflight.NextImmediate()
|
||||
require.False(t, ok)
|
||||
}
|
||||
118
backend/services/mochi/listeners/http_sysinfo.go
Normal file
118
backend/services/mochi/listeners/http_sysinfo.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint.
|
||||
type HTTPStats struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // the http server
|
||||
log *zerolog.Logger // server logger
|
||||
sysInfo *system.Info // pointers to the server data
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewHTTPStats initialises and returns a new HTTP listener, listening on an address.
|
||||
func NewHTTPStats(id, address string, config *Config, sysInfo *system.Info) *HTTPStats {
|
||||
if config == nil {
|
||||
config = new(Config)
|
||||
}
|
||||
return &HTTPStats{
|
||||
id: id,
|
||||
address: address,
|
||||
sysInfo: sysInfo,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *HTTPStats) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *HTTPStats) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *HTTPStats) Protocol() string {
|
||||
if l.listen != nil && l.listen.TLSConfig != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
return "http"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *HTTPStats) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.jsonHandler)
|
||||
l.listen = &http.Server{
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
if l.config.TLSConfig != nil {
|
||||
l.listen.TLSConfig = l.config.TLSConfig
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts listening for new connections and serving responses.
|
||||
func (l *HTTPStats) Serve(establish EstablishFn) {
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *HTTPStats) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
// jsonHandler is an HTTP handler which outputs the $SYS stats as JSON.
|
||||
func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) {
|
||||
info := *l.sysInfo.Clone()
|
||||
|
||||
out, err := json.MarshalIndent(info, "", "\t")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
}
|
||||
|
||||
w.Write(out)
|
||||
}
|
||||
127
backend/services/mochi/listeners/http_sysinfo_test.go
Normal file
127
backend/services/mochi/listeners/http_sysinfo_test.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/system"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHTTPStats(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, nil, nil)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestHTTPStatsID(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, nil, nil)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestHTTPStatsAddress(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, nil, nil)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestHTTPStatsProtocol(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, nil, nil)
|
||||
require.Equal(t, "http", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPStatsTLSProtocol(t *testing.T) {
|
||||
l := NewHTTPStats("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
}, nil)
|
||||
|
||||
l.Init(nil)
|
||||
require.Equal(t, "https", l.Protocol())
|
||||
}
|
||||
|
||||
func TestHTTPStatsInit(t *testing.T) {
|
||||
sysInfo := new(system.Info)
|
||||
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, l.sysInfo)
|
||||
require.Equal(t, sysInfo, l.sysInfo)
|
||||
require.NotNil(t, l.listen)
|
||||
require.Equal(t, testAddr, l.listen.Addr)
|
||||
}
|
||||
|
||||
func TestHTTPStatsServeAndClose(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
// setup http stats listener
|
||||
l := NewHTTPStats("t1", testAddr, nil, sysInfo)
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// get body from stats address
|
||||
resp, err := http.Get("http://localhost" + testAddr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
// decode body from json and check data
|
||||
v := new(system.Info)
|
||||
err = json.Unmarshal(body, v)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", v.Version)
|
||||
|
||||
// ensure listening is closed
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
|
||||
_, err = http.Get("http://localhost" + testAddr)
|
||||
require.Error(t, err)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestHTTPStatsServeTLSAndClose(t *testing.T) {
|
||||
sysInfo := &system.Info{
|
||||
Version: "test",
|
||||
}
|
||||
|
||||
l := NewHTTPStats("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
}, sysInfo)
|
||||
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
l.Close(MockCloser)
|
||||
}
|
||||
135
backend/services/mochi/listeners/listeners.go
Normal file
135
backend/services/mochi/listeners/listeners.go
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Config contains configuration values for a listener.
|
||||
type Config struct {
|
||||
// TLSConfig is a tls.Config configuration to be used with the listener.
|
||||
// See examples folder for basic and mutual-tls use.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// EstablishFn is a callback function for establishing new clients.
|
||||
type EstablishFn func(id string, c net.Conn) error
|
||||
|
||||
// CloseFunc is a callback function for closing all listener clients.
|
||||
type CloseFn func(id string)
|
||||
|
||||
// Listener is an interface for network listeners. A network listener listens
|
||||
// for incoming client connections and adds them to the server.
|
||||
type Listener interface {
|
||||
Init(*zerolog.Logger) error // open the network address
|
||||
Serve(EstablishFn) // starting actively listening for new connections
|
||||
ID() string // return the id of the listener
|
||||
Address() string // the address of the listener
|
||||
Protocol() string // the protocol in use by the listener
|
||||
Close(CloseFn) // stop and close the listener
|
||||
}
|
||||
|
||||
// Listeners contains the network listeners for the broker.
|
||||
type Listeners struct {
|
||||
wg sync.WaitGroup // a waitgroup that waits for all listeners to finish.
|
||||
internal map[string]Listener // a map of active listeners.
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// New returns a new instance of Listeners.
|
||||
func New() *Listeners {
|
||||
return &Listeners{
|
||||
internal: map[string]Listener{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new listener to the listeners map, keyed on id.
|
||||
func (l *Listeners) Add(val Listener) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.internal[val.ID()] = val
|
||||
}
|
||||
|
||||
// Get returns the value of a listener if it exists.
|
||||
func (l *Listeners) Get(id string) (Listener, bool) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
val, ok := l.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the length of the listeners map.
|
||||
func (l *Listeners) Len() int {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
return len(l.internal)
|
||||
}
|
||||
|
||||
// Delete removes a listener from the internal map.
|
||||
func (l *Listeners) Delete(id string) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
delete(l.internal, id)
|
||||
}
|
||||
|
||||
// Serve starts a listener serving from the internal map.
|
||||
func (l *Listeners) Serve(id string, establisher EstablishFn) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
listener := l.internal[id]
|
||||
|
||||
go func(e EstablishFn) {
|
||||
defer l.wg.Done()
|
||||
l.wg.Add(1)
|
||||
listener.Serve(e)
|
||||
}(establisher)
|
||||
}
|
||||
|
||||
// ServeAll starts all listeners serving from the internal map.
|
||||
func (l *Listeners) ServeAll(establisher EstablishFn) {
|
||||
l.RLock()
|
||||
i := 0
|
||||
ids := make([]string, len(l.internal))
|
||||
for id := range l.internal {
|
||||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
l.RUnlock()
|
||||
|
||||
for _, id := range ids {
|
||||
l.Serve(id, establisher)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops a listener from the internal map.
|
||||
func (l *Listeners) Close(id string, closer CloseFn) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
if listener, ok := l.internal[id]; ok {
|
||||
listener.Close(closer)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseAll iterates and closes all registered listeners.
|
||||
func (l *Listeners) CloseAll(closer CloseFn) {
|
||||
l.RLock()
|
||||
i := 0
|
||||
ids := make([]string, len(l.internal))
|
||||
for id := range l.internal {
|
||||
ids[i] = id
|
||||
i++
|
||||
}
|
||||
l.RUnlock()
|
||||
|
||||
for _, id := range ids {
|
||||
l.Close(id, closer)
|
||||
}
|
||||
l.wg.Wait()
|
||||
}
|
||||
177
backend/services/mochi/listeners/listeners_test.go
Normal file
177
backend/services/mochi/listeners/listeners_test.go
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testAddr = ":22222"
|
||||
|
||||
var (
|
||||
logger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.Disabled)
|
||||
|
||||
testCertificate = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
|
||||
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
|
||||
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
|
||||
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
|
||||
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
|
||||
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
|
||||
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
|
||||
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
|
||||
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
|
||||
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
|
||||
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
|
||||
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
|
||||
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
|
||||
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
|
||||
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
|
||||
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
|
||||
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
|
||||
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
|
||||
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
|
||||
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
|
||||
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
|
||||
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
|
||||
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
tlsConfigBasic *tls.Config
|
||||
)
|
||||
|
||||
func init() {
|
||||
cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Basic TLS Config
|
||||
tlsConfigBasic = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
l := New()
|
||||
require.NotNil(t, l.internal)
|
||||
}
|
||||
|
||||
func TestAddListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
}
|
||||
|
||||
func TestGetListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
require.Contains(t, l.internal, "t2")
|
||||
|
||||
g, ok := l.Get("t1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, g.ID(), "t1")
|
||||
}
|
||||
|
||||
func TestLenListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
require.Contains(t, l.internal, "t2")
|
||||
require.Equal(t, 2, l.Len())
|
||||
}
|
||||
|
||||
func TestDeleteListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
require.Contains(t, l.internal, "t1")
|
||||
l.Delete("t1")
|
||||
_, ok := l.Get("t1")
|
||||
require.False(t, ok)
|
||||
require.Nil(t, l.internal["t1"])
|
||||
}
|
||||
|
||||
func TestServeListener(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Serve("t1", MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
|
||||
l.Close("t1", MockCloser)
|
||||
require.False(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
}
|
||||
|
||||
func TestServeAllListeners(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
l.Add(NewMockListener("t3", testAddr))
|
||||
l.ServeAll(MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
|
||||
l.Close("t1", MockCloser)
|
||||
l.Close("t2", MockCloser)
|
||||
l.Close("t3", MockCloser)
|
||||
|
||||
require.False(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.False(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.False(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
}
|
||||
|
||||
func TestCloseListener(t *testing.T) {
|
||||
l := New()
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
l.Add(mocked)
|
||||
l.Serve("t1", MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close("t1", func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.True(t, closed)
|
||||
}
|
||||
|
||||
func TestCloseAllListeners(t *testing.T) {
|
||||
l := New()
|
||||
l.Add(NewMockListener("t1", testAddr))
|
||||
l.Add(NewMockListener("t2", testAddr))
|
||||
l.Add(NewMockListener("t3", testAddr))
|
||||
l.ServeAll(MockEstablisher)
|
||||
time.Sleep(time.Millisecond)
|
||||
require.True(t, l.internal["t1"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t2"].(*MockListener).IsServing())
|
||||
require.True(t, l.internal["t3"].(*MockListener).IsServing())
|
||||
|
||||
closed := make(map[string]bool)
|
||||
l.CloseAll(func(id string) {
|
||||
closed[id] = true
|
||||
})
|
||||
require.Contains(t, closed, "t1")
|
||||
require.Contains(t, closed, "t2")
|
||||
require.Contains(t, closed, "t3")
|
||||
require.True(t, closed["t1"])
|
||||
require.True(t, closed["t2"])
|
||||
require.True(t, closed["t3"])
|
||||
}
|
||||
103
backend/services/mochi/listeners/mock.go
Normal file
103
backend/services/mochi/listeners/mock.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// MockEstablisher is a function signature which can be used in testing.
|
||||
func MockEstablisher(id string, c net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockCloser is a function signature which can be used in testing.
|
||||
func MockCloser(id string) {}
|
||||
|
||||
// MockListener is a mock listener for establishing client connections.
|
||||
type MockListener struct {
|
||||
sync.RWMutex
|
||||
id string // the id of the listener
|
||||
address string // the network address the listener binds to
|
||||
Config *Config // configuration for the listener
|
||||
done chan bool // indicate the listener is done
|
||||
Serving bool // indicate the listener is serving
|
||||
Listening bool // indiciate the listener is listening
|
||||
ErrListen bool // throw an error on listen
|
||||
}
|
||||
|
||||
// NewMockListener returns a new instance of MockListener.
|
||||
func NewMockListener(id, address string) *MockListener {
|
||||
return &MockListener{
|
||||
id: id,
|
||||
address: address,
|
||||
done: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Serve serves the mock listener.
|
||||
func (l *MockListener) Serve(establisher EstablishFn) {
|
||||
l.Lock()
|
||||
l.Serving = true
|
||||
l.Unlock()
|
||||
|
||||
for range l.done {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *MockListener) Init(log *zerolog.Logger) error {
|
||||
if l.ErrListen {
|
||||
return fmt.Errorf("listen failure")
|
||||
}
|
||||
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Listening = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// ID returns the id of the mock listener.
|
||||
func (l *MockListener) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *MockListener) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *MockListener) Protocol() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
// Close closes the mock listener.
|
||||
func (l *MockListener) Close(closer CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
l.Serving = false
|
||||
closer(l.id)
|
||||
close(l.done)
|
||||
}
|
||||
|
||||
// IsServing indicates whether the mock listener is serving.
|
||||
func (l *MockListener) IsServing() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Serving
|
||||
}
|
||||
|
||||
// IsListening indicates whether the mock listener is listening.
|
||||
func (l *MockListener) IsListening() bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
return l.Listening
|
||||
}
|
||||
99
backend/services/mochi/listeners/mock_test.go
Normal file
99
backend/services/mochi/listeners/mock_test.go
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMockEstablisher(t *testing.T) {
|
||||
_, w := net.Pipe()
|
||||
err := MockEstablisher("t1", w)
|
||||
require.NoError(t, err)
|
||||
w.Close()
|
||||
}
|
||||
|
||||
func TestNewMockListener(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.id)
|
||||
require.Equal(t, testAddr, mocked.address)
|
||||
}
|
||||
func TestMockListenerID(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.ID())
|
||||
}
|
||||
|
||||
func TestMockListenerAddress(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, testAddr, mocked.Address())
|
||||
}
|
||||
func TestMockListenerProtocol(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "mock", mocked.Protocol())
|
||||
}
|
||||
|
||||
func TestNewMockListenerIsListening(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsListening())
|
||||
}
|
||||
|
||||
func TestNewMockListenerIsServing(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsServing())
|
||||
}
|
||||
|
||||
func TestNewMockListenerInit(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, "t1", mocked.id)
|
||||
require.Equal(t, testAddr, mocked.address)
|
||||
|
||||
require.Equal(t, false, mocked.IsListening())
|
||||
err := mocked.Init(nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, mocked.IsListening())
|
||||
}
|
||||
|
||||
func TestNewMockListenerInitFailure(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
mocked.ErrListen = true
|
||||
err := mocked.Init(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMockListenerServe(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
require.Equal(t, false, mocked.IsServing())
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
mocked.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
|
||||
require.Equal(t, true, mocked.IsServing())
|
||||
|
||||
var closed bool
|
||||
mocked.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
|
||||
mocked.Init(nil)
|
||||
}
|
||||
|
||||
func TestMockListenerClose(t *testing.T) {
|
||||
mocked := NewMockListener("t1", testAddr)
|
||||
var closed bool
|
||||
mocked.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
88
backend/services/mochi/listeners/net.go
Normal file
88
backend/services/mochi/listeners/net.go
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Net is a listener for establishing client connections on basic TCP protocol.
|
||||
type Net struct { // [MQTT-4.2.0-1]
|
||||
mu sync.Mutex
|
||||
listener net.Listener // a net.Listener which will listen for new clients
|
||||
id string // the internal id of the listener
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewNet initialises and returns a listener serving incoming connections on the given net.Listener
|
||||
func NewNet(id string, listener net.Listener) *Net {
|
||||
return &Net{
|
||||
id: id,
|
||||
listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Net) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *Net) Address() string {
|
||||
return l.listener.Addr().String()
|
||||
}
|
||||
|
||||
// Protocol returns the network of the listener.
|
||||
func (l *Net) Protocol() string {
|
||||
return l.listener.Addr().Network()
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Net) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *Net) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Net) Close(closeClients CloseFn) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listener != nil {
|
||||
err := l.listener.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
105
backend/services/mochi/listeners/net_test.go
Normal file
105
backend/services/mochi/listeners/net_test.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewNet(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.id)
|
||||
}
|
||||
|
||||
func TestNetID(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestNetAddress(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, n.Addr().String(), l.Address())
|
||||
}
|
||||
|
||||
func TestNetProtocol(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestNetInit(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestNetServeAndClose(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestNetEstablishThenEnd(t *testing.T) {
|
||||
n, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
l := NewNet("t1", n)
|
||||
err = l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("tcp", n.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
108
backend/services/mochi/listeners/tcp.go
Normal file
108
backend/services/mochi/listeners/tcp.go
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// TCP is a listener for establishing client connections on basic TCP protocol.
|
||||
type TCP struct { // [MQTT-4.2.0-1]
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
listen net.Listener // a net.Listener which will listen for new clients
|
||||
config *Config // configuration values for the listener
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewTCP initialises and returns a new TCP listener, listening on an address.
|
||||
func NewTCP(id, address string, config *Config) *TCP {
|
||||
if config == nil {
|
||||
config = new(Config)
|
||||
}
|
||||
|
||||
return &TCP{
|
||||
id: id,
|
||||
address: address,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *TCP) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *TCP) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *TCP) Protocol() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *TCP) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
if l.config.TLSConfig != nil {
|
||||
l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig)
|
||||
} else {
|
||||
l.listen, err = net.Listen("tcp", l.address)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new TCP connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *TCP) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *TCP) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
131
backend/services/mochi/listeners/tcp_test.go
Normal file
131
backend/services/mochi/listeners/tcp_test.go
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTCP(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestTCPID(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestTCPAddress(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestTCPProtocol(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestTCPProtocolTLS(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
|
||||
l.Init(&logger)
|
||||
defer l.listen.Close()
|
||||
require.Equal(t, "tcp", l.Protocol())
|
||||
}
|
||||
|
||||
func TestTCPInit(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewTCP("t2", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err = l2.Init(&logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l2.config.TLSConfig)
|
||||
}
|
||||
|
||||
func TestTCPServeAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestTCPServeTLSAndClose(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.Equal(t, true, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestTCPEstablishThenEnd(t *testing.T) {
|
||||
l := NewTCP("t1", testAddr, nil)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("tcp", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
98
backend/services/mochi/listeners/unixsock.go
Normal file
98
backend/services/mochi/listeners/unixsock.go
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
|
||||
type UnixSock struct {
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener.
|
||||
address string // the network address to bind to.
|
||||
listen net.Listener // a net.Listener which will listen for new clients.
|
||||
log *zerolog.Logger // server logger
|
||||
end uint32 // ensure the close methods are only called once.
|
||||
}
|
||||
|
||||
// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
|
||||
func NewUnixSock(id, address string) *UnixSock {
|
||||
return &UnixSock{
|
||||
id: id,
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *UnixSock) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *UnixSock) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *UnixSock) Protocol() string {
|
||||
return "unix"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *UnixSock) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
var err error
|
||||
_ = os.Remove(l.address)
|
||||
l.listen, err = net.Listen("unix", l.address)
|
||||
return err
|
||||
}
|
||||
|
||||
// Serve starts waiting for new UnixSock connections, and calls the establish
|
||||
// connection callback for any received.
|
||||
func (l *UnixSock) Serve(establish EstablishFn) {
|
||||
for {
|
||||
if atomic.LoadUint32(&l.end) == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := l.listen.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&l.end) == 0 {
|
||||
go func() {
|
||||
err = establish(l.id, conn)
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *UnixSock) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
if l.listen != nil {
|
||||
err := l.listen.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
96
backend/services/mochi/listeners/unixsock_test.go
Normal file
96
backend/services/mochi/listeners/unixsock_test.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: jason@zgwit.com
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testUnixAddr = "mochi.sock"
|
||||
|
||||
func TestNewUnixSock(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testUnixAddr, l.address)
|
||||
}
|
||||
|
||||
func TestUnixSockID(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestUnixSockAddress(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, testUnixAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestUnixSockProtocol(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
require.Equal(t, "unix", l.Protocol())
|
||||
}
|
||||
|
||||
func TestUnixSockInit(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
l.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := NewUnixSock("t2", testUnixAddr)
|
||||
err = l2.Init(&logger)
|
||||
l2.Close(MockCloser)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnixSockServeAndClose(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
|
||||
l.Close(MockCloser) // coverage: close closed
|
||||
l.Serve(MockEstablisher) // coverage: serve closed
|
||||
}
|
||||
|
||||
func TestUnixSockEstablishThenEnd(t *testing.T) {
|
||||
l := NewUnixSock("t1", testUnixAddr)
|
||||
err := l.Init(&logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
established := make(chan bool)
|
||||
go func() {
|
||||
l.Serve(func(id string, c net.Conn) error {
|
||||
established <- true
|
||||
return errors.New("ending") // return an error to exit immediately
|
||||
})
|
||||
o <- true
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
net.Dial("unix", l.listen.Addr().String())
|
||||
require.Equal(t, true, <-established)
|
||||
l.Close(MockCloser)
|
||||
<-o
|
||||
}
|
||||
178
backend/services/mochi/listeners/websocket.go
Normal file
178
backend/services/mochi/listeners/websocket.go
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidMessage indicates that a message payload was not valid.
|
||||
ErrInvalidMessage = errors.New("message type not binary")
|
||||
)
|
||||
|
||||
// Websocket is a listener for establishing websocket connections.
|
||||
type Websocket struct { // [MQTT-4.2.0-1]
|
||||
sync.RWMutex
|
||||
id string // the internal id of the listener
|
||||
address string // the network address to bind to
|
||||
config *Config // configuration values for the listener
|
||||
listen *http.Server // an http server for serving websocket connections
|
||||
log *zerolog.Logger // server logger
|
||||
establish EstablishFn // the server's establish connection handler
|
||||
upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection.
|
||||
end uint32 // ensure the close methods are only called once
|
||||
}
|
||||
|
||||
// NewWebsocket initialises and returns a new Websocket listener, listening on an address.
|
||||
func NewWebsocket(id, address string, config *Config) *Websocket {
|
||||
if config == nil {
|
||||
config = new(Config)
|
||||
}
|
||||
|
||||
return &Websocket{
|
||||
id: id,
|
||||
address: address,
|
||||
config: config,
|
||||
upgrader: &websocket.Upgrader{
|
||||
Subprotocols: []string{"mqtt"},
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the id of the listener.
|
||||
func (l *Websocket) ID() string {
|
||||
return l.id
|
||||
}
|
||||
|
||||
// Address returns the address of the listener.
|
||||
func (l *Websocket) Address() string {
|
||||
return l.address
|
||||
}
|
||||
|
||||
// Protocol returns the address of the listener.
|
||||
func (l *Websocket) Protocol() string {
|
||||
if l.config.TLSConfig != nil {
|
||||
return "wss"
|
||||
}
|
||||
|
||||
return "ws"
|
||||
}
|
||||
|
||||
// Init initializes the listener.
|
||||
func (l *Websocket) Init(log *zerolog.Logger) error {
|
||||
l.log = log
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", l.handler)
|
||||
l.listen = &http.Server{
|
||||
Addr: l.address,
|
||||
Handler: mux,
|
||||
TLSConfig: l.config.TLSConfig,
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handler upgrades and handles an incoming websocket connection.
|
||||
func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := l.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
err = l.establish(l.id, &wsConn{c.UnderlyingConn(), c})
|
||||
if err != nil {
|
||||
l.log.Warn().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
// Serve starts waiting for new Websocket connections, and calls the connection
|
||||
// establishment callback for any received.
|
||||
func (l *Websocket) Serve(establish EstablishFn) {
|
||||
l.establish = establish
|
||||
|
||||
if l.listen.TLSConfig != nil {
|
||||
l.listen.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
l.listen.ListenAndServe()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener and any client connections.
|
||||
func (l *Websocket) Close(closeClients CloseFn) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
l.listen.Shutdown(ctx)
|
||||
}
|
||||
|
||||
closeClients(l.id)
|
||||
}
|
||||
|
||||
// wsConn is a websocket connection which satisfies the net.Conn interface.
|
||||
type wsConn struct {
|
||||
net.Conn
|
||||
c *websocket.Conn
|
||||
}
|
||||
|
||||
// Read reads the next span of bytes from the websocket connection and returns the number of bytes read.
|
||||
func (ws *wsConn) Read(p []byte) (int, error) {
|
||||
op, r, err := ws.c.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if op != websocket.BinaryMessage {
|
||||
err = ErrInvalidMessage
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var n, br int
|
||||
for {
|
||||
br, err = r.Read(p[n:])
|
||||
n += br
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes bytes to the websocket connection.
|
||||
func (ws *wsConn) Write(p []byte) (int, error) {
|
||||
err := ws.c.WriteMessage(websocket.BinaryMessage, p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Close signals the underlying websocket conn to close.
|
||||
func (ws *wsConn) Close() error {
|
||||
return ws.Conn.Close()
|
||||
}
|
||||
114
backend/services/mochi/listeners/websocket_test.go
Normal file
114
backend/services/mochi/listeners/websocket_test.go
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package listeners
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewWebsocket(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Equal(t, "t1", l.id)
|
||||
require.Equal(t, testAddr, l.address)
|
||||
}
|
||||
|
||||
func TestWebsocketID(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Equal(t, "t1", l.ID())
|
||||
}
|
||||
|
||||
func TestWebsocketAddress(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Equal(t, testAddr, l.Address())
|
||||
}
|
||||
|
||||
func TestWebsocketProtocol(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Equal(t, "ws", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsocketProtocoTLS(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
require.Equal(t, "wss", l.Protocol())
|
||||
}
|
||||
|
||||
func TestWebsockeInit(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
require.Nil(t, l.listen)
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, l.listen)
|
||||
}
|
||||
|
||||
func TestWebsocketServeAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
|
||||
require.True(t, closed)
|
||||
<-o
|
||||
}
|
||||
|
||||
func TestWebsocketServeTLSAndClose(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, &Config{
|
||||
TLSConfig: tlsConfigBasic,
|
||||
})
|
||||
err := l.Init(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
o := make(chan bool)
|
||||
go func(o chan bool) {
|
||||
l.Serve(MockEstablisher)
|
||||
o <- true
|
||||
}(o)
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
var closed bool
|
||||
l.Close(func(id string) {
|
||||
closed = true
|
||||
})
|
||||
require.Equal(t, true, closed)
|
||||
}
|
||||
|
||||
func TestWebsocketUpgrade(t *testing.T) {
|
||||
l := NewWebsocket("t1", testAddr, nil)
|
||||
l.Init(nil)
|
||||
|
||||
e := make(chan bool)
|
||||
l.establish = func(id string, c net.Conn) error {
|
||||
e <- true
|
||||
return nil
|
||||
}
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(l.handler))
|
||||
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, <-e)
|
||||
|
||||
s.Close()
|
||||
ws.Close()
|
||||
}
|
||||
172
backend/services/mochi/packets/codec.go
Normal file
172
backend/services/mochi/packets/codec.go
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// bytesToString provides a zero-alloc no-copy byte to string conversion.
|
||||
// via https://github.com/golang/go/issues/25484#issuecomment-391415660
|
||||
func bytesToString(bs []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&bs))
|
||||
}
|
||||
|
||||
// decodeUint16 extracts the value of two bytes from a byte array.
|
||||
func decodeUint16(buf []byte, offset int) (uint16, int, error) {
|
||||
if len(buf) < offset+2 {
|
||||
return 0, 0, ErrMalformedOffsetUintOutOfRange
|
||||
}
|
||||
|
||||
return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil
|
||||
}
|
||||
|
||||
// decodeUint32 extracts the value of four bytes from a byte array.
|
||||
func decodeUint32(buf []byte, offset int) (uint32, int, error) {
|
||||
if len(buf) < offset+4 {
|
||||
return 0, 0, ErrMalformedOffsetUintOutOfRange
|
||||
}
|
||||
|
||||
return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil
|
||||
}
|
||||
|
||||
// decodeString extracts a string from a byte array, beginning at an offset.
|
||||
func decodeString(buf []byte, offset int) (string, int, error) {
|
||||
b, n, err := decodeBytes(buf, offset)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5]
|
||||
return "", 0, ErrMalformedInvalidUTF8
|
||||
}
|
||||
|
||||
return bytesToString(b), n, nil
|
||||
}
|
||||
|
||||
// validUTF8 checks if the byte array contains valid UTF-8 characters.
|
||||
func validUTF8(b []byte) bool {
|
||||
return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2]
|
||||
}
|
||||
|
||||
// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads.
|
||||
func decodeBytes(buf []byte, offset int) ([]byte, int, error) {
|
||||
length, next, err := decodeUint16(buf, offset)
|
||||
if err != nil {
|
||||
return make([]byte, 0), 0, err
|
||||
}
|
||||
|
||||
if next+int(length) > len(buf) {
|
||||
return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange
|
||||
}
|
||||
|
||||
return buf[next : next+int(length)], next + int(length), nil
|
||||
}
|
||||
|
||||
// decodeByte extracts the value of a byte from a byte array.
|
||||
func decodeByte(buf []byte, offset int) (byte, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return 0, 0, ErrMalformedOffsetByteOutOfRange
|
||||
}
|
||||
return buf[offset], offset + 1, nil
|
||||
}
|
||||
|
||||
// decodeByteBool extracts the value of a byte from a byte array and returns a bool.
|
||||
func decodeByteBool(buf []byte, offset int) (bool, int, error) {
|
||||
if len(buf) <= offset {
|
||||
return false, 0, ErrMalformedOffsetBoolOutOfRange
|
||||
}
|
||||
return 1&buf[offset] > 0, offset + 1, nil
|
||||
}
|
||||
|
||||
// encodeBool returns a byte instead of a bool.
|
||||
func encodeBool(b bool) byte {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads.
|
||||
func encodeBytes(val []byte) []byte {
|
||||
// In most circumstances the number of bytes being encoded is small.
|
||||
// Setting the cap to a low amount allows us to account for those without
|
||||
// triggering allocation growth on append unless we need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, val...)
|
||||
}
|
||||
|
||||
// encodeUint16 encodes a uint16 value to a byte array.
|
||||
func encodeUint16(val uint16) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, val)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeUint32 encodes a uint16 value to a byte array.
|
||||
func encodeUint32(val uint32) []byte {
|
||||
buf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(buf, val)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeString encodes a string to a byte array.
|
||||
func encodeString(val string) []byte {
|
||||
// Like encodeBytes, we set the cap to a small number to avoid
|
||||
// triggering allocation growth on append unless we absolutely need to.
|
||||
buf := make([]byte, 2, 32)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(val)))
|
||||
return append(buf, []byte(val)...)
|
||||
}
|
||||
|
||||
// encodeLength writes length bits for the header.
|
||||
func encodeLength(b *bytes.Buffer, length int64) {
|
||||
// 1.5.5 Variable Byte Integer encode non-normative
|
||||
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
|
||||
for {
|
||||
eb := byte(length % 128)
|
||||
length /= 128
|
||||
if length > 0 {
|
||||
eb |= 0x80
|
||||
}
|
||||
b.WriteByte(eb)
|
||||
if length == 0 {
|
||||
break // [MQTT-1.5.5-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeLength(b io.ByteReader) (n, bu int, err error) {
|
||||
// see 1.5.5 Variable Byte Integer decode non-normative
|
||||
// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027
|
||||
var multiplier uint32
|
||||
var value uint32
|
||||
bu = 1
|
||||
for {
|
||||
eb, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, bu, err
|
||||
}
|
||||
|
||||
value |= uint32(eb&127) << multiplier
|
||||
if value > 268435455 {
|
||||
return 0, bu, ErrMalformedVariableByteInteger
|
||||
}
|
||||
|
||||
if (eb & 128) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
multiplier += 7
|
||||
bu++
|
||||
}
|
||||
|
||||
return int(value), bu, nil
|
||||
}
|
||||
422
backend/services/mochi/packets/codec_test.go
Normal file
422
backend/services/mochi/packets/codec_test.go
Normal file
|
|
@ -0,0 +1,422 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBytesToString(t *testing.T) {
|
||||
b := []byte{'a', 'b', 'c'}
|
||||
require.Equal(t, "abc", bytesToString(b))
|
||||
}
|
||||
|
||||
func TestDecodeString(t *testing.T) {
|
||||
expect := []struct {
|
||||
name string
|
||||
rawBytes []byte
|
||||
result string
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: "a/b/c/d",
|
||||
},
|
||||
{
|
||||
offset: 14,
|
||||
rawBytes: []byte{
|
||||
Connect << 4, 17, // Fixed header
|
||||
0, 6, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name
|
||||
3, // Protocol Version
|
||||
0, // Packet Flags
|
||||
0, 30, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'h', 'e', 'y', // Client ID "zen"},
|
||||
},
|
||||
result: "hey",
|
||||
},
|
||||
{
|
||||
offset: 2,
|
||||
rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97},
|
||||
result: "1/2/3/4/a/b/c/d/e/^/@/!",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38},
|
||||
result: "x/y/z",
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 5,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 9,
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'},
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 17,
|
||||
rawBytes: []byte{
|
||||
Connect << 4, 0, // Fixed header
|
||||
0, 4, // Protocol Name - MSB+LSB
|
||||
'M', 'Q', 'T', 'T', // Protocol Name
|
||||
4, // Protocol Version
|
||||
0, // Flags
|
||||
0, 20, // Keepalive
|
||||
0, 3, // Client ID - MSB+LSB
|
||||
'z', 'e', 'n', // Client ID "zen"
|
||||
0, 6, // Will Topic - MSB+LSB
|
||||
'l',
|
||||
},
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
offset: 0,
|
||||
rawBytes: []byte{0, 7, 0xc3, 0x28, 98, 47, 99, 47, 100},
|
||||
shouldFail: ErrMalformedInvalidUTF8,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeString(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeStringZeroWidthNoBreak(t *testing.T) { // [MQTT-1.5.4-3]
|
||||
result, _, err := decodeString([]byte{0, 3, 0xEF, 0xBB, 0xBF}, 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "\ufeff", result)
|
||||
}
|
||||
|
||||
func TestDecodeBytes(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result []uint8
|
||||
next int
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // truncated connect packet (clean session)
|
||||
result: []byte{0x4d, 0x51, 0x54, 0x54},
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // truncated connect packet, only checking start
|
||||
result: []byte{0x4d, 0x51, 0x54, 0x54},
|
||||
next: 6,
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 0,
|
||||
shouldFail: ErrMalformedOffsetBytesOutOfRange,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, _, err := decodeBytes(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByte(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint8
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes
|
||||
result: uint8(0x00),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x04),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x4d),
|
||||
offset: 2,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 81, 84, 84},
|
||||
result: uint8(0x51),
|
||||
offset: 3,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 4, 77, 80, 82, 84},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetByteOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByte(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUint16(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint16
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x07),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97},
|
||||
result: uint16(0x761),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+2, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUint32(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result uint32
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0, 0, 0, 7, 8},
|
||||
result: uint32(7),
|
||||
offset: 0,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 0, 1, 226, 64, 8},
|
||||
result: uint32(123456),
|
||||
offset: 1,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0, 7, 255, 47},
|
||||
offset: 8,
|
||||
shouldFail: ErrMalformedOffsetUintOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeUint32(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, i+4, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeByteBool(t *testing.T) {
|
||||
expect := []struct {
|
||||
rawBytes []byte
|
||||
result bool
|
||||
offset int
|
||||
shouldFail error
|
||||
}{
|
||||
{
|
||||
rawBytes: []byte{0x00, 0x00},
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
rawBytes: []byte{0x01, 0x00},
|
||||
offset: 5,
|
||||
shouldFail: ErrMalformedOffsetBoolOutOfRange,
|
||||
},
|
||||
}
|
||||
|
||||
for i, wanted := range expect {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset)
|
||||
if wanted.shouldFail != nil {
|
||||
require.True(t, errors.Is(err, wanted.shouldFail), "want %v to be a %v", err, wanted.shouldFail)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.result, result)
|
||||
require.Equal(t, 1, offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeLength(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{0x78})
|
||||
n, bu, err := DecodeLength(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 120, n)
|
||||
require.Equal(t, 1, bu)
|
||||
|
||||
b = bytes.NewBuffer([]byte{255, 255, 255, 127})
|
||||
n, bu, err = DecodeLength(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 268435455, n)
|
||||
require.Equal(t, 4, bu)
|
||||
}
|
||||
|
||||
func TestDecodeLengthErrors(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
_, _, err := DecodeLength(b)
|
||||
require.Error(t, err)
|
||||
|
||||
b = bytes.NewBuffer([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f})
|
||||
_, _, err = DecodeLength(b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrMalformedVariableByteInteger)
|
||||
}
|
||||
|
||||
func TestEncodeBool(t *testing.T) {
|
||||
result := encodeBool(true)
|
||||
require.Equal(t, byte(1), result)
|
||||
|
||||
result = encodeBool(false)
|
||||
require.Equal(t, byte(0), result)
|
||||
|
||||
// Check failure.
|
||||
result = encodeBool(false)
|
||||
require.NotEqual(t, byte(1), result)
|
||||
}
|
||||
|
||||
func TestEncodeBytes(t *testing.T) {
|
||||
result := encodeBytes([]byte("testing"))
|
||||
require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result)
|
||||
|
||||
result = encodeBytes([]byte("testing"))
|
||||
require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result)
|
||||
}
|
||||
|
||||
func TestEncodeUint16(t *testing.T) {
|
||||
result := encodeUint16(0)
|
||||
require.Equal(t, []byte{0x00, 0x00}, result)
|
||||
|
||||
result = encodeUint16(32767)
|
||||
require.Equal(t, []byte{0x7f, 0xff}, result)
|
||||
|
||||
result = encodeUint16(math.MaxUint16)
|
||||
require.Equal(t, []byte{0xff, 0xff}, result)
|
||||
}
|
||||
|
||||
func TestEncodeUint32(t *testing.T) {
|
||||
result := encodeUint32(7)
|
||||
require.Equal(t, []byte{0x00, 0x00, 0x00, 0x07}, result)
|
||||
|
||||
result = encodeUint32(32767)
|
||||
require.Equal(t, []byte{0, 0, 127, 255}, result)
|
||||
|
||||
result = encodeUint32(math.MaxUint32)
|
||||
require.Equal(t, []byte{255, 255, 255, 255}, result)
|
||||
}
|
||||
|
||||
func TestEncodeString(t *testing.T) {
|
||||
result := encodeString("testing")
|
||||
require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result)
|
||||
|
||||
result = encodeString("")
|
||||
require.Equal(t, []uint8{0x00, 0x00}, result)
|
||||
|
||||
result = encodeString("a")
|
||||
require.Equal(t, []uint8{0x00, 0x01, 0x61}, result)
|
||||
|
||||
result = encodeString("b")
|
||||
require.NotEqual(t, []uint8{0x00, 0x00}, result)
|
||||
}
|
||||
|
||||
func TestEncodeLength(t *testing.T) {
|
||||
b := new(bytes.Buffer)
|
||||
encodeLength(b, 120)
|
||||
require.Equal(t, []byte{0x78}, b.Bytes())
|
||||
|
||||
b = new(bytes.Buffer)
|
||||
encodeLength(b, math.MaxInt64)
|
||||
require.Equal(t, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestValidUTF8(t *testing.T) {
|
||||
require.True(t, validUTF8([]byte{0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}))
|
||||
require.False(t, validUTF8([]byte{0xff, 0xff}))
|
||||
require.False(t, validUTF8([]byte{0x74, 0x00, 0x73, 0x74}))
|
||||
}
|
||||
147
backend/services/mochi/packets/codes.go
Normal file
147
backend/services/mochi/packets/codes.go
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
// Code contains a reason code and reason string for a response.
|
||||
type Code struct {
|
||||
Reason string
|
||||
Code byte
|
||||
}
|
||||
|
||||
// String returns the readable reason for a code.
|
||||
func (c Code) String() string {
|
||||
return c.Reason
|
||||
}
|
||||
|
||||
// Error returns the readable reason for a code.
|
||||
func (c Code) Error() string {
|
||||
return c.Reason
|
||||
}
|
||||
|
||||
var (
|
||||
// QosCodes indicicates the reason codes for each Qos byte.
|
||||
QosCodes = map[byte]Code{
|
||||
0: CodeGrantedQos0,
|
||||
1: CodeGrantedQos1,
|
||||
2: CodeGrantedQos2,
|
||||
}
|
||||
|
||||
CodeSuccess = Code{Code: 0x00, Reason: "success"}
|
||||
CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"}
|
||||
CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"}
|
||||
CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"}
|
||||
CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"}
|
||||
CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"}
|
||||
CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"}
|
||||
CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"}
|
||||
CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"}
|
||||
CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"}
|
||||
ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"}
|
||||
ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"}
|
||||
ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"}
|
||||
ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"}
|
||||
ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"}
|
||||
ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"}
|
||||
ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"}
|
||||
ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"}
|
||||
ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"}
|
||||
ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"}
|
||||
ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"}
|
||||
ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"}
|
||||
ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"}
|
||||
ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"}
|
||||
ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"}
|
||||
ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"}
|
||||
ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"}
|
||||
ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"}
|
||||
ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"}
|
||||
ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"}
|
||||
ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"}
|
||||
ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"}
|
||||
ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"}
|
||||
ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"}
|
||||
ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"}
|
||||
ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"}
|
||||
ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"}
|
||||
ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"}
|
||||
ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"}
|
||||
ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"}
|
||||
ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
|
||||
ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"}
|
||||
ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"}
|
||||
ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"}
|
||||
ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"}
|
||||
ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"}
|
||||
ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"}
|
||||
ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"}
|
||||
ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"}
|
||||
ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"}
|
||||
ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"}
|
||||
ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"}
|
||||
ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"}
|
||||
ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"}
|
||||
ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"}
|
||||
ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"}
|
||||
ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"}
|
||||
ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"}
|
||||
ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"}
|
||||
ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true with no qos"}
|
||||
ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"}
|
||||
ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"}
|
||||
ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"}
|
||||
ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"}
|
||||
ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"}
|
||||
ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"}
|
||||
ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"}
|
||||
ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"}
|
||||
ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"}
|
||||
ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"}
|
||||
ErrServerBusy = Code{Code: 0x89, Reason: "server busy"}
|
||||
ErrBanned = Code{Code: 0x8A, Reason: "banned"}
|
||||
ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"}
|
||||
ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"}
|
||||
ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"}
|
||||
ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"}
|
||||
ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"}
|
||||
ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"}
|
||||
ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"}
|
||||
ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"}
|
||||
ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"}
|
||||
ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"}
|
||||
ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"}
|
||||
ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"}
|
||||
ErrQuotaExceeded = Code{Code: 0x97, Reason: "quota exceeded"}
|
||||
ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"}
|
||||
ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"}
|
||||
ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"}
|
||||
ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"}
|
||||
ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"}
|
||||
ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"}
|
||||
ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"}
|
||||
ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptiptions not supported"}
|
||||
ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"}
|
||||
ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"}
|
||||
ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"}
|
||||
ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"}
|
||||
|
||||
// MQTTv3 specific bytes.
|
||||
Err3UnsupportedProtocolVersion = Code{Code: 0x01}
|
||||
Err3ClientIdentifierNotValid = Code{Code: 0x02}
|
||||
Err3ServerUnavailable = Code{Code: 0x03}
|
||||
ErrMalformedUsernameOrPassword = Code{Code: 0x04}
|
||||
Err3NotAuthorized = Code{Code: 0x05}
|
||||
|
||||
// V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes.
|
||||
// This is required because MQTTv3 has different return byte specification.
|
||||
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257
|
||||
V5CodesToV3 = map[Code]Code{
|
||||
ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion,
|
||||
ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid,
|
||||
ErrServerUnavailable: Err3ServerUnavailable,
|
||||
ErrMalformedUsername: ErrMalformedUsernameOrPassword,
|
||||
ErrMalformedPassword: ErrMalformedUsernameOrPassword,
|
||||
ErrBadUsernameOrPassword: Err3NotAuthorized,
|
||||
}
|
||||
)
|
||||
29
backend/services/mochi/packets/codes_test.go
Normal file
29
backend/services/mochi/packets/codes_test.go
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCodesString(t *testing.T) {
|
||||
c := Code{
|
||||
Reason: "test",
|
||||
Code: 0x1,
|
||||
}
|
||||
|
||||
require.Equal(t, "test", c.String())
|
||||
}
|
||||
|
||||
func TestCodesErrorr(t *testing.T) {
|
||||
c := Code{
|
||||
Reason: "error",
|
||||
Code: 0x1,
|
||||
}
|
||||
|
||||
require.Equal(t, "error", error(c).Error())
|
||||
}
|
||||
63
backend/services/mochi/packets/fixedheader.go
Normal file
63
backend/services/mochi/packets/fixedheader.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// FixedHeader contains the values of the fixed header portion of the MQTT packet.
|
||||
type FixedHeader struct {
|
||||
Remaining int `json:"remaining"` // the number of remaining bytes in the payload.
|
||||
Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
|
||||
Qos byte `json:"qos"` // indicates the quality of service expected.
|
||||
Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time.
|
||||
Retain bool `json:"retain"` // whether the message should be retained.
|
||||
}
|
||||
|
||||
// Encode encodes the FixedHeader and returns a bytes buffer.
|
||||
func (fh *FixedHeader) Encode(buf *bytes.Buffer) {
|
||||
buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain))
|
||||
encodeLength(buf, int64(fh.Remaining))
|
||||
}
|
||||
|
||||
// Decode extracts the specification bits from the header byte.
|
||||
func (fh *FixedHeader) Decode(hb byte) error {
|
||||
fh.Type = hb >> 4 // Get the message type from the first 4 bytes.
|
||||
|
||||
switch fh.Type {
|
||||
case Publish:
|
||||
if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 {
|
||||
return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4]
|
||||
}
|
||||
|
||||
fh.Dup = (hb>>3)&0x01 > 0 // is duplicate
|
||||
fh.Qos = (hb >> 1) & 0x03 // qos flag
|
||||
fh.Retain = hb&0x01 > 0 // is retain flag
|
||||
case Pubrel:
|
||||
fallthrough
|
||||
case Subscribe:
|
||||
fallthrough
|
||||
case Unsubscribe:
|
||||
if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.8.1-1] [MQTT-3.10.1-1]
|
||||
return ErrMalformedFlags
|
||||
}
|
||||
|
||||
fh.Qos = (hb >> 1) & 0x03
|
||||
default:
|
||||
if (hb>>0)&0x01 != 0 ||
|
||||
(hb>>1)&0x01 != 0 ||
|
||||
(hb>>2)&0x01 != 0 ||
|
||||
(hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1]
|
||||
return ErrMalformedFlags
|
||||
}
|
||||
}
|
||||
|
||||
if fh.Qos == 0 && fh.Dup {
|
||||
return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
237
backend/services/mochi/packets/fixedheader_test.go
Normal file
237
backend/services/mochi/packets/fixedheader_test.go
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fixedHeaderTable struct {
|
||||
desc string
|
||||
rawBytes []byte
|
||||
header FixedHeader
|
||||
packetError bool
|
||||
expect error
|
||||
}
|
||||
|
||||
var fixedHeaderExpected = []fixedHeaderTable{
|
||||
{
|
||||
desc: "connect",
|
||||
rawBytes: []byte{Connect << 4, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "connack",
|
||||
rawBytes: []byte{Connack << 4, 0x00},
|
||||
header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish",
|
||||
rawBytes: []byte{Publish << 4, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 1",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 1 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 2",
|
||||
rawBytes: []byte{Publish<<4 | 2<<1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish qos 2 retain",
|
||||
rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 0",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
expect: ErrProtocolViolationDupNoQos,
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 0 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0},
|
||||
expect: ErrProtocolViolationDupNoQos,
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 1 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "publish dup qos 2 retain",
|
||||
rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00},
|
||||
header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "puback",
|
||||
rawBytes: []byte{Puback << 4, 0x00},
|
||||
header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubrec",
|
||||
rawBytes: []byte{Pubrec << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubrel",
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pubcomp",
|
||||
rawBytes: []byte{Pubcomp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "subscribe",
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "suback",
|
||||
rawBytes: []byte{Suback << 4, 0x00},
|
||||
header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "unsubscribe",
|
||||
rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "unsuback",
|
||||
rawBytes: []byte{Unsuback << 4, 0x00},
|
||||
header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pingreq",
|
||||
rawBytes: []byte{Pingreq << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "pingresp",
|
||||
rawBytes: []byte{Pingresp << 4, 0x00},
|
||||
header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "disconnect",
|
||||
rawBytes: []byte{Disconnect << 4, 0x00},
|
||||
header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
{
|
||||
desc: "auth",
|
||||
rawBytes: []byte{Auth << 4, 0x00},
|
||||
header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0},
|
||||
},
|
||||
|
||||
// remaining length
|
||||
{
|
||||
desc: "remaining length 10",
|
||||
rawBytes: []byte{Publish << 4, 0x0a},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 512",
|
||||
rawBytes: []byte{Publish << 4, 0x80, 0x04},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 978",
|
||||
rawBytes: []byte{Publish << 4, 0xd2, 0x07},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978},
|
||||
},
|
||||
{
|
||||
desc: "remaining length 20202",
|
||||
rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102},
|
||||
},
|
||||
{
|
||||
desc: "remaining length oversize",
|
||||
rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01},
|
||||
header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333},
|
||||
packetError: true,
|
||||
},
|
||||
|
||||
// Invalid flags for packet
|
||||
{
|
||||
desc: "invalid type dup is true",
|
||||
rawBytes: []byte{Connect<<4 | 1<<3, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid type qos is 1",
|
||||
rawBytes: []byte{Connect<<4 | 1<<1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid type retain is true",
|
||||
rawBytes: []byte{Connect<<4 | 1, 0x00},
|
||||
header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid publish qos bits 1 + 2 set",
|
||||
rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00},
|
||||
header: FixedHeader{Type: Publish},
|
||||
expect: ErrProtocolViolationQosOutOfRange,
|
||||
},
|
||||
{
|
||||
desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0",
|
||||
rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00},
|
||||
header: FixedHeader{Type: Pubrel, Qos: 1},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
{
|
||||
desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0",
|
||||
rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00},
|
||||
header: FixedHeader{Type: Subscribe, Qos: 1},
|
||||
expect: ErrMalformedFlags,
|
||||
},
|
||||
}
|
||||
|
||||
func TestFixedHeaderEncode(t *testing.T) {
|
||||
for _, wanted := range fixedHeaderExpected {
|
||||
t.Run(wanted.desc, func(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
wanted.header.Encode(buf)
|
||||
if wanted.expect == nil {
|
||||
require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()))
|
||||
require.EqualValues(t, wanted.rawBytes, buf.Bytes())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFixedHeaderDecode(t *testing.T) {
|
||||
for _, wanted := range fixedHeaderExpected {
|
||||
t.Run(wanted.desc, func(t *testing.T) {
|
||||
fh := new(FixedHeader)
|
||||
err := fh.Decode(wanted.rawBytes[0])
|
||||
if wanted.expect != nil {
|
||||
require.Equal(t, wanted.expect, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wanted.header.Type, fh.Type)
|
||||
require.Equal(t, wanted.header.Dup, fh.Dup)
|
||||
require.Equal(t, wanted.header.Qos, fh.Qos)
|
||||
require.Equal(t, wanted.header.Retain, fh.Retain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1145
backend/services/mochi/packets/packets.go
Normal file
1145
backend/services/mochi/packets/packets.go
Normal file
File diff suppressed because it is too large
Load Diff
505
backend/services/mochi/packets/packets_test.go
Normal file
505
backend/services/mochi/packets/packets_test.go
Normal file
|
|
@ -0,0 +1,505 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const pkInfo = "packet type %v, %s"
|
||||
|
||||
var packetList = []byte{
|
||||
Connect,
|
||||
Connack,
|
||||
Publish,
|
||||
Puback,
|
||||
Pubrec,
|
||||
Pubrel,
|
||||
Pubcomp,
|
||||
Subscribe,
|
||||
Suback,
|
||||
Unsubscribe,
|
||||
Unsuback,
|
||||
Pingreq,
|
||||
Pingresp,
|
||||
Disconnect,
|
||||
Auth,
|
||||
}
|
||||
|
||||
var pkTable = []TPacketCase{
|
||||
TPacketData[Connect].Get(TConnectMqtt311),
|
||||
TPacketData[Connect].Get(TConnectMqtt5),
|
||||
TPacketData[Connect].Get(TConnectUserPassLWT),
|
||||
TPacketData[Connack].Get(TConnackAcceptedMqtt5),
|
||||
TPacketData[Connack].Get(TConnackAcceptedNoSession),
|
||||
TPacketData[Publish].Get(TPublishBasic),
|
||||
TPacketData[Publish].Get(TPublishMqtt5),
|
||||
TPacketData[Puback].Get(TPuback),
|
||||
TPacketData[Pubrec].Get(TPubrec),
|
||||
TPacketData[Pubrel].Get(TPubrel),
|
||||
TPacketData[Pubcomp].Get(TPubcomp),
|
||||
TPacketData[Subscribe].Get(TSubscribe),
|
||||
TPacketData[Subscribe].Get(TSubscribeMqtt5),
|
||||
TPacketData[Suback].Get(TSuback),
|
||||
TPacketData[Unsubscribe].Get(TUnsubscribe),
|
||||
TPacketData[Unsubscribe].Get(TUnsubscribeMqtt5),
|
||||
TPacketData[Pingreq].Get(TPingreq),
|
||||
TPacketData[Pingresp].Get(TPingresp),
|
||||
TPacketData[Disconnect].Get(TDisconnect),
|
||||
TPacketData[Disconnect].Get(TDisconnectMqtt5),
|
||||
}
|
||||
|
||||
func TestNewPackets(t *testing.T) {
|
||||
s := NewPackets()
|
||||
require.NotNil(t, s.internal)
|
||||
}
|
||||
|
||||
func TestPacketsAdd(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
}
|
||||
|
||||
func TestPacketsGet(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
|
||||
pk, ok := s.Get("cl1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "a1", pk.TopicName)
|
||||
}
|
||||
|
||||
func TestPacketsGetAll(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
s.Add("cl3", Packet{TopicName: "a3"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Contains(t, s.internal, "cl3")
|
||||
|
||||
subs := s.GetAll()
|
||||
require.Len(t, subs, 3)
|
||||
}
|
||||
|
||||
func TestPacketsLen(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
s.Add("cl2", Packet{TopicName: "a2"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Equal(t, 2, s.Len())
|
||||
}
|
||||
|
||||
func TestSPacketsDelete(t *testing.T) {
|
||||
s := NewPackets()
|
||||
s.Add("cl1", Packet{TopicName: "a1"})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
|
||||
s.Delete("cl1")
|
||||
_, ok := s.Get("cl1")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestFormatPacketID(t *testing.T) {
|
||||
for _, id := range []uint16{0, 7, 0x100, 0xffff} {
|
||||
packet := &Packet{PacketID: id}
|
||||
require.Equal(t, fmt.Sprint(id), packet.FormatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscriptionOptionsEncodeDecode(t *testing.T) {
|
||||
p := &Subscription{
|
||||
Qos: 2,
|
||||
NoLocal: true,
|
||||
RetainAsPublished: true,
|
||||
RetainHandling: 2,
|
||||
}
|
||||
x := new(Subscription)
|
||||
x.decode(p.encode())
|
||||
require.Equal(t, *p, *x)
|
||||
|
||||
p = &Subscription{
|
||||
Qos: 1,
|
||||
NoLocal: false,
|
||||
RetainAsPublished: false,
|
||||
RetainHandling: 1,
|
||||
}
|
||||
x = new(Subscription)
|
||||
x.decode(p.encode())
|
||||
require.Equal(t, *p, *x)
|
||||
}
|
||||
|
||||
func TestPacketEncode(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if !encodeTestOK(wanted) {
|
||||
return
|
||||
}
|
||||
|
||||
pk := new(Packet)
|
||||
copier.Copy(pk, wanted.Packet)
|
||||
require.Equal(t, pkt, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectEncode(buf)
|
||||
case Connack:
|
||||
err = pk.ConnackEncode(buf)
|
||||
case Publish:
|
||||
err = pk.PublishEncode(buf)
|
||||
case Puback:
|
||||
err = pk.PubackEncode(buf)
|
||||
case Pubrec:
|
||||
err = pk.PubrecEncode(buf)
|
||||
case Pubrel:
|
||||
err = pk.PubrelEncode(buf)
|
||||
case Pubcomp:
|
||||
err = pk.PubcompEncode(buf)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeEncode(buf)
|
||||
case Suback:
|
||||
err = pk.SubackEncode(buf)
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeEncode(buf)
|
||||
case Unsuback:
|
||||
err = pk.UnsubackEncode(buf)
|
||||
case Pingreq:
|
||||
err = pk.PingreqEncode(buf)
|
||||
case Pingresp:
|
||||
err = pk.PingrespEncode(buf)
|
||||
case Disconnect:
|
||||
err = pk.DisconnectEncode(buf)
|
||||
case Auth:
|
||||
err = pk.AuthEncode(buf)
|
||||
}
|
||||
if wanted.Expect != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
|
||||
encoded := buf.Bytes()
|
||||
|
||||
// If ActualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc).
|
||||
if len(wanted.ActualBytes) > 0 {
|
||||
wanted.RawBytes = wanted.ActualBytes
|
||||
}
|
||||
require.EqualValues(t, wanted.RawBytes, encoded, pkInfo, pkt, wanted.Desc)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketDecode(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if !decodeTestOK(wanted) {
|
||||
return
|
||||
}
|
||||
|
||||
pk := &Packet{FixedHeader: FixedHeader{Type: pkt}}
|
||||
pk.Mods.AllowResponseInfo = true
|
||||
pk.FixedHeader.Decode(wanted.RawBytes[0])
|
||||
if len(wanted.RawBytes) > 0 {
|
||||
pk.FixedHeader.Remaining = int(wanted.RawBytes[1])
|
||||
}
|
||||
|
||||
if wanted.Packet != nil && wanted.Packet.ProtocolVersion != 0 {
|
||||
pk.ProtocolVersion = wanted.Packet.ProtocolVersion
|
||||
}
|
||||
|
||||
buf := wanted.RawBytes[2:]
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectDecode(buf)
|
||||
case Connack:
|
||||
err = pk.ConnackDecode(buf)
|
||||
case Publish:
|
||||
err = pk.PublishDecode(buf)
|
||||
case Puback:
|
||||
err = pk.PubackDecode(buf)
|
||||
case Pubrec:
|
||||
err = pk.PubrecDecode(buf)
|
||||
case Pubrel:
|
||||
err = pk.PubrelDecode(buf)
|
||||
case Pubcomp:
|
||||
err = pk.PubcompDecode(buf)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeDecode(buf)
|
||||
case Suback:
|
||||
err = pk.SubackDecode(buf)
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeDecode(buf)
|
||||
case Unsuback:
|
||||
err = pk.UnsubackDecode(buf)
|
||||
case Pingreq:
|
||||
err = pk.PingreqDecode(buf)
|
||||
case Pingresp:
|
||||
err = pk.PingrespDecode(buf)
|
||||
case Disconnect:
|
||||
err = pk.DisconnectDecode(buf)
|
||||
case Auth:
|
||||
err = pk.AuthDecode(buf)
|
||||
}
|
||||
|
||||
if wanted.FailFirst != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
require.ErrorIs(t, err, wanted.FailFirst, pkInfo, pkt, wanted.Desc)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.EqualValues(t, wanted.Packet.Filters, pk.Filters, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Type, pk.FixedHeader.Type, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Dup, pk.FixedHeader.Dup, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Qos, pk.FixedHeader.Qos, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.FixedHeader.Retain, pk.FixedHeader.Retain, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
if pkt == Connect {
|
||||
// we use ProtocolVersion for controlling packet encoding, but we don't need to test
|
||||
// against it unless it's a connect packet.
|
||||
require.Equal(t, wanted.Packet.ProtocolVersion, pk.ProtocolVersion, pkInfo, pkt, wanted.Desc)
|
||||
}
|
||||
require.Equal(t, wanted.Packet.Connect.ProtocolName, pk.Connect.ProtocolName, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Clean, pk.Connect.Clean, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.ClientIdentifier, pk.Connect.ClientIdentifier, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Keepalive, pk.Connect.Keepalive, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.Connect.UsernameFlag, pk.Connect.UsernameFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Username, pk.Connect.Username, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.PasswordFlag, pk.Connect.PasswordFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.Password, pk.Connect.Password, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.Connect.WillFlag, pk.Connect.WillFlag, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillTopic, pk.Connect.WillTopic, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillPayload, pk.Connect.WillPayload, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillQos, pk.Connect.WillQos, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.Connect.WillRetain, pk.Connect.WillRetain, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.Equal(t, wanted.Packet.ReasonCodes, pk.ReasonCodes, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.ReasonCode, pk.ReasonCode, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.SessionPresent, pk.SessionPresent, pkInfo, pkt, wanted.Desc)
|
||||
require.Equal(t, wanted.Packet.PacketID, pk.PacketID, pkInfo, pkt, wanted.Desc)
|
||||
|
||||
require.EqualValues(t, wanted.Packet.Properties, pk.Properties)
|
||||
require.EqualValues(t, wanted.Packet.Connect.WillProperties, pk.Connect.WillProperties)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
for _, pkt := range packetList {
|
||||
require.Contains(t, TPacketData, pkt)
|
||||
for _, wanted := range TPacketData[pkt] {
|
||||
t.Run(fmt.Sprintf("%s %s", PacketNames[pkt], wanted.Desc), func(t *testing.T) {
|
||||
if wanted.Group == "validate" || wanted.Primary {
|
||||
pk := wanted.Packet
|
||||
var err error
|
||||
switch pkt {
|
||||
case Connect:
|
||||
err = pk.ConnectValidate()
|
||||
case Publish:
|
||||
err = pk.PublishValidate(1024)
|
||||
case Subscribe:
|
||||
err = pk.SubscribeValidate()
|
||||
case Unsubscribe:
|
||||
err = pk.UnsubscribeValidate()
|
||||
case Auth:
|
||||
err = pk.AuthValidate()
|
||||
}
|
||||
|
||||
if wanted.Expect != nil {
|
||||
require.Error(t, err, pkInfo, pkt, wanted.Desc)
|
||||
require.ErrorIs(t, wanted.Expect, err, pkInfo, pkt, wanted.Desc)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAckValidatePubrec(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
CodeNoMatchingSubscribers.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicNameInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
ErrQuotaExceeded.Code,
|
||||
ErrPayloadFormatInvalid.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrec}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidatePubrel(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
ErrPacketIdentifierNotFound.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidatePubcomp(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
ErrPacketIdentifierNotFound.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubcomp}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Pubrel}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidateSuback(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeGrantedQos0.Code,
|
||||
CodeGrantedQos1.Code,
|
||||
CodeGrantedQos2.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicFilterInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
ErrQuotaExceeded.Code,
|
||||
ErrSharedSubscriptionsNotSupported.Code,
|
||||
ErrSubscriptionIdentifiersNotSupported.Code,
|
||||
ErrWildcardSubscriptionsNotSupported.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Suback}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestAckValidateUnsuback(t *testing.T) {
|
||||
for _, b := range []byte{
|
||||
CodeSuccess.Code,
|
||||
CodeNoSubscriptionExisted.Code,
|
||||
ErrUnspecifiedError.Code,
|
||||
ErrImplementationSpecificError.Code,
|
||||
ErrNotAuthorized.Code,
|
||||
ErrTopicFilterInvalid.Code,
|
||||
ErrPacketIdentifierInUse.Code,
|
||||
} {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: b}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Unsuback}, ReasonCode: ErrClientIdentifierTooLong.Code}
|
||||
require.False(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestReasonCodeValidMisc(t *testing.T) {
|
||||
pk := Packet{FixedHeader: FixedHeader{Type: Connack}, ReasonCode: CodeSuccess.Code}
|
||||
require.True(t, pk.ReasonCodeValid())
|
||||
}
|
||||
|
||||
func TestCopy(t *testing.T) {
|
||||
for _, tt := range pkTable {
|
||||
pkc := tt.Packet.Copy(true)
|
||||
|
||||
require.Equal(t, tt.Packet.FixedHeader.Qos, pkc.FixedHeader.Qos, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, false, pkc.FixedHeader.Dup, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, false, pkc.FixedHeader.Retain, pkInfo, tt.Case, tt.Desc)
|
||||
|
||||
require.Equal(t, tt.Packet.TopicName, pkc.TopicName, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.ClientIdentifier, pkc.Connect.ClientIdentifier, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Keepalive, pkc.Connect.Keepalive, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ProtocolVersion, pkc.ProtocolVersion, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.PasswordFlag, pkc.Connect.PasswordFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.UsernameFlag, pkc.Connect.UsernameFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillQos, pkc.Connect.WillQos, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillTopic, pkc.Connect.WillTopic, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillFlag, pkc.Connect.WillFlag, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillRetain, pkc.Connect.WillRetain, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillProperties, pkc.Connect.WillProperties, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Properties, pkc.Properties, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Clean, pkc.Connect.Clean, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.SessionPresent, pkc.SessionPresent, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ReasonCode, pkc.ReasonCode, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.PacketID, pkc.PacketID, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Filters, pkc.Filters, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Payload, pkc.Payload, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Password, pkc.Connect.Password, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.Username, pkc.Connect.Username, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.ProtocolName, pkc.Connect.ProtocolName, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Connect.WillPayload, pkc.Connect.WillPayload, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.ReasonCodes, pkc.ReasonCodes, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Created, pkc.Created, pkInfo, tt.Case, tt.Desc)
|
||||
require.Equal(t, tt.Packet.Origin, pkc.Origin, pkInfo, tt.Case, tt.Desc)
|
||||
require.EqualValues(t, pkc.Properties, tt.Packet.Properties)
|
||||
|
||||
pkcc := tt.Packet.Copy(false)
|
||||
require.Equal(t, uint16(0), pkcc.PacketID, pkInfo, tt.Case, tt.Desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSubscription(t *testing.T) {
|
||||
sub := Subscription{
|
||||
Filter: "a/b/c",
|
||||
RetainHandling: 0,
|
||||
Qos: 0,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: false,
|
||||
Identifier: 1,
|
||||
}
|
||||
|
||||
sub2 := Subscription{
|
||||
Filter: "a/b/d",
|
||||
RetainHandling: 0,
|
||||
Qos: 2,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: true,
|
||||
Identifier: 2,
|
||||
}
|
||||
|
||||
expect := Subscription{
|
||||
Filter: "a/b/c",
|
||||
RetainHandling: 0,
|
||||
Qos: 2,
|
||||
RetainAsPublished: false,
|
||||
NoLocal: true,
|
||||
Identifier: 1,
|
||||
Identifiers: map[string]int{
|
||||
"a/b/c": 1,
|
||||
"a/b/d": 2,
|
||||
},
|
||||
}
|
||||
require.Equal(t, expect, sub.Merge(sub2))
|
||||
}
|
||||
479
backend/services/mochi/packets/properties.go
Normal file
479
backend/services/mochi/packets/properties.go
Normal file
|
|
@ -0,0 +1,479 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
PropPayloadFormat byte = 1
|
||||
PropMessageExpiryInterval byte = 2
|
||||
PropContentType byte = 3
|
||||
PropResponseTopic byte = 8
|
||||
PropCorrelationData byte = 9
|
||||
PropSubscriptionIdentifier byte = 11
|
||||
PropSessionExpiryInterval byte = 17
|
||||
PropAssignedClientID byte = 18
|
||||
PropServerKeepAlive byte = 19
|
||||
PropAuthenticationMethod byte = 21
|
||||
PropAuthenticationData byte = 22
|
||||
PropRequestProblemInfo byte = 23
|
||||
PropWillDelayInterval byte = 24
|
||||
PropRequestResponseInfo byte = 25
|
||||
PropResponseInfo byte = 26
|
||||
PropServerReference byte = 28
|
||||
PropReasonString byte = 31
|
||||
PropReceiveMaximum byte = 33
|
||||
PropTopicAliasMaximum byte = 34
|
||||
PropTopicAlias byte = 35
|
||||
PropMaximumQos byte = 36
|
||||
PropRetainAvailable byte = 37
|
||||
PropUser byte = 38
|
||||
PropMaximumPacketSize byte = 39
|
||||
PropWildcardSubAvailable byte = 40
|
||||
PropSubIDAvailable byte = 41
|
||||
PropSharedSubAvailable byte = 42
|
||||
)
|
||||
|
||||
// validPacketProperties indicates which properties are valid for which packet types.
|
||||
var validPacketProperties = map[byte]map[byte]byte{
|
||||
PropPayloadFormat: {Publish: 1},
|
||||
PropMessageExpiryInterval: {Publish: 1},
|
||||
PropContentType: {Publish: 1},
|
||||
PropResponseTopic: {Publish: 1},
|
||||
PropCorrelationData: {Publish: 1},
|
||||
PropSubscriptionIdentifier: {Publish: 1, Subscribe: 1},
|
||||
PropSessionExpiryInterval: {Connect: 1, Connack: 1, Disconnect: 1},
|
||||
PropAssignedClientID: {Connack: 1},
|
||||
PropServerKeepAlive: {Connack: 1},
|
||||
PropAuthenticationMethod: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropAuthenticationData: {Connect: 1, Connack: 1, Auth: 1},
|
||||
PropRequestProblemInfo: {Connect: 1},
|
||||
PropWillDelayInterval: {Connect: 1},
|
||||
PropRequestResponseInfo: {Connect: 1},
|
||||
PropResponseInfo: {Connack: 1},
|
||||
PropServerReference: {Connack: 1, Disconnect: 1},
|
||||
PropReasonString: {Connack: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Suback: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
|
||||
PropReceiveMaximum: {Connect: 1, Connack: 1},
|
||||
PropTopicAliasMaximum: {Connect: 1, Connack: 1},
|
||||
PropTopicAlias: {Publish: 1},
|
||||
PropMaximumQos: {Connack: 1},
|
||||
PropRetainAvailable: {Connack: 1},
|
||||
PropUser: {Connect: 1, Connack: 1, Publish: 1, Puback: 1, Pubrec: 1, Pubrel: 1, Pubcomp: 1, Subscribe: 1, Suback: 1, Unsubscribe: 1, Unsuback: 1, Disconnect: 1, Auth: 1},
|
||||
PropMaximumPacketSize: {Connect: 1, Connack: 1},
|
||||
PropWildcardSubAvailable: {Connack: 1},
|
||||
PropSubIDAvailable: {Connack: 1},
|
||||
PropSharedSubAvailable: {Connack: 1},
|
||||
}
|
||||
|
||||
// UserProperty is an arbitrary key-value pair for a packet user properties array.
|
||||
type UserProperty struct { // [MQTT-1.5.7-1]
|
||||
Key string `json:"k"`
|
||||
Val string `json:"v"`
|
||||
}
|
||||
|
||||
// Properties contains all of the mqtt v5 properties available for a packet.
|
||||
// Some properties have valid values of 0 or not-present. In this case, we opt for
|
||||
// property flags to indicate the usage of property.
|
||||
// Refer to mqtt v5 2.2.2.2 Property spec for more information.
|
||||
type Properties struct {
|
||||
CorrelationData []byte `json:"cd"`
|
||||
SubscriptionIdentifier []int `json:"si"`
|
||||
AuthenticationData []byte `json:"ad"`
|
||||
User []UserProperty `json:"user"`
|
||||
ContentType string `json:"ct"`
|
||||
ResponseTopic string `json:"rt"`
|
||||
AssignedClientID string `json:"aci"`
|
||||
AuthenticationMethod string `json:"am"`
|
||||
ResponseInfo string `json:"ri"`
|
||||
ServerReference string `json:"sr"`
|
||||
ReasonString string `json:"rs"`
|
||||
MessageExpiryInterval uint32 `json:"me"`
|
||||
SessionExpiryInterval uint32 `json:"sei"`
|
||||
WillDelayInterval uint32 `json:"wdi"`
|
||||
MaximumPacketSize uint32 `json:"mps"`
|
||||
ServerKeepAlive uint16 `json:"ska"`
|
||||
ReceiveMaximum uint16 `json:"rm"`
|
||||
TopicAliasMaximum uint16 `json:"tam"`
|
||||
TopicAlias uint16 `json:"ta"`
|
||||
PayloadFormat byte `json:"pf"`
|
||||
PayloadFormatFlag bool `json:"fpf"`
|
||||
SessionExpiryIntervalFlag bool `json:"fsei"`
|
||||
ServerKeepAliveFlag bool `json:"fska"`
|
||||
RequestProblemInfo byte `json:"rpi"`
|
||||
RequestProblemInfoFlag bool `json:"frpi"`
|
||||
RequestResponseInfo byte `json:"rri"`
|
||||
TopicAliasFlag bool `json:"fta"`
|
||||
MaximumQos byte `json:"mqos"`
|
||||
MaximumQosFlag bool `json:"fmqos"`
|
||||
RetainAvailable byte `json:"ra"`
|
||||
RetainAvailableFlag bool `json:"fra"`
|
||||
WildcardSubAvailable byte `json:"wsa"`
|
||||
WildcardSubAvailableFlag bool `json:"fwsa"`
|
||||
SubIDAvailable byte `json:"sida"`
|
||||
SubIDAvailableFlag bool `json:"fsida"`
|
||||
SharedSubAvailable byte `json:"ssa"`
|
||||
SharedSubAvailableFlag bool `json:"fssa"`
|
||||
}
|
||||
|
||||
// Copy creates a new Properties struct with copies of the values.
|
||||
func (p *Properties) Copy(allowTransfer bool) Properties {
|
||||
pr := Properties{
|
||||
PayloadFormat: p.PayloadFormat, // [MQTT-3.3.2-4]
|
||||
PayloadFormatFlag: p.PayloadFormatFlag,
|
||||
MessageExpiryInterval: p.MessageExpiryInterval,
|
||||
ContentType: p.ContentType, // [MQTT-3.3.2-20]
|
||||
ResponseTopic: p.ResponseTopic, // [MQTT-3.3.2-15]
|
||||
SessionExpiryInterval: p.SessionExpiryInterval,
|
||||
SessionExpiryIntervalFlag: p.SessionExpiryIntervalFlag,
|
||||
AssignedClientID: p.AssignedClientID,
|
||||
ServerKeepAlive: p.ServerKeepAlive,
|
||||
ServerKeepAliveFlag: p.ServerKeepAliveFlag,
|
||||
AuthenticationMethod: p.AuthenticationMethod,
|
||||
RequestProblemInfo: p.RequestProblemInfo,
|
||||
RequestProblemInfoFlag: p.RequestProblemInfoFlag,
|
||||
WillDelayInterval: p.WillDelayInterval,
|
||||
RequestResponseInfo: p.RequestResponseInfo,
|
||||
ResponseInfo: p.ResponseInfo,
|
||||
ServerReference: p.ServerReference,
|
||||
ReasonString: p.ReasonString,
|
||||
ReceiveMaximum: p.ReceiveMaximum,
|
||||
TopicAliasMaximum: p.TopicAliasMaximum,
|
||||
TopicAlias: 0, // NB; do not copy topic alias [MQTT-3.3.2-7] + we do not send to clients (currently) [MQTT-3.1.2-26] [MQTT-3.1.2-27]
|
||||
MaximumQos: p.MaximumQos,
|
||||
MaximumQosFlag: p.MaximumQosFlag,
|
||||
RetainAvailable: p.RetainAvailable,
|
||||
RetainAvailableFlag: p.RetainAvailableFlag,
|
||||
MaximumPacketSize: p.MaximumPacketSize,
|
||||
WildcardSubAvailable: p.WildcardSubAvailable,
|
||||
WildcardSubAvailableFlag: p.WildcardSubAvailableFlag,
|
||||
SubIDAvailable: p.SubIDAvailable,
|
||||
SubIDAvailableFlag: p.SubIDAvailableFlag,
|
||||
SharedSubAvailable: p.SharedSubAvailable,
|
||||
SharedSubAvailableFlag: p.SharedSubAvailableFlag,
|
||||
}
|
||||
|
||||
if allowTransfer {
|
||||
pr.TopicAlias = p.TopicAlias
|
||||
pr.TopicAliasFlag = p.TopicAliasFlag
|
||||
}
|
||||
|
||||
if len(p.CorrelationData) > 0 {
|
||||
pr.CorrelationData = append([]byte{}, p.CorrelationData...) // [MQTT-3.3.2-16]
|
||||
}
|
||||
|
||||
if len(p.SubscriptionIdentifier) > 0 {
|
||||
pr.SubscriptionIdentifier = append([]int{}, p.SubscriptionIdentifier...)
|
||||
}
|
||||
|
||||
if len(p.AuthenticationData) > 0 {
|
||||
pr.AuthenticationData = append([]byte{}, p.AuthenticationData...)
|
||||
}
|
||||
|
||||
if len(p.User) > 0 {
|
||||
pr.User = []UserProperty{}
|
||||
for _, v := range p.User {
|
||||
pr.User = append(pr.User, UserProperty{ // [MQTT-3.3.2-17]
|
||||
Key: v.Key,
|
||||
Val: v.Val,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return pr
|
||||
}
|
||||
|
||||
// canEncode returns true if the property type is valid for the packet type.
|
||||
func (p *Properties) canEncode(pkt byte, k byte) bool {
|
||||
return validPacketProperties[k][pkt] == 1
|
||||
}
|
||||
|
||||
// Encode encodes properties into a bytes buffer.
|
||||
func (p *Properties) Encode(pk *Packet, b *bytes.Buffer, n int) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
pkt := pk.FixedHeader.Type
|
||||
|
||||
if p.canEncode(pkt, PropPayloadFormat) && p.PayloadFormatFlag {
|
||||
buf.WriteByte(PropPayloadFormat)
|
||||
buf.WriteByte(p.PayloadFormat)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMessageExpiryInterval) && p.MessageExpiryInterval > 0 {
|
||||
buf.WriteByte(PropMessageExpiryInterval)
|
||||
buf.Write(encodeUint32(p.MessageExpiryInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropContentType) && p.ContentType != "" {
|
||||
buf.WriteByte(PropContentType)
|
||||
buf.Write(encodeString(p.ContentType)) // [MQTT-3.3.2-19]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseTopic) && // [MQTT-3.3.2-14]
|
||||
p.ResponseTopic != "" && !strings.ContainsAny(p.ResponseTopic, "+#") { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseTopic)
|
||||
buf.Write(encodeString(p.ResponseTopic)) // [MQTT-3.3.2-13]
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropCorrelationData) && len(p.CorrelationData) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropCorrelationData)
|
||||
buf.Write(encodeBytes(p.CorrelationData))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSubscriptionIdentifier) && len(p.SubscriptionIdentifier) > 0 {
|
||||
for _, v := range p.SubscriptionIdentifier {
|
||||
if v > 0 {
|
||||
buf.WriteByte(PropSubscriptionIdentifier)
|
||||
encodeLength(&buf, int64(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSessionExpiryInterval) && p.SessionExpiryIntervalFlag { // [MQTT-3.14.2-2]
|
||||
buf.WriteByte(PropSessionExpiryInterval)
|
||||
buf.Write(encodeUint32(p.SessionExpiryInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAssignedClientID) && p.AssignedClientID != "" {
|
||||
buf.WriteByte(PropAssignedClientID)
|
||||
buf.Write(encodeString(p.AssignedClientID))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropServerKeepAlive) && p.ServerKeepAliveFlag {
|
||||
buf.WriteByte(PropServerKeepAlive)
|
||||
buf.Write(encodeUint16(p.ServerKeepAlive))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAuthenticationMethod) && p.AuthenticationMethod != "" {
|
||||
buf.WriteByte(PropAuthenticationMethod)
|
||||
buf.Write(encodeString(p.AuthenticationMethod))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropAuthenticationData) && len(p.AuthenticationData) > 0 {
|
||||
buf.WriteByte(PropAuthenticationData)
|
||||
buf.Write(encodeBytes(p.AuthenticationData))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRequestProblemInfo) && p.RequestProblemInfoFlag {
|
||||
buf.WriteByte(PropRequestProblemInfo)
|
||||
buf.WriteByte(p.RequestProblemInfo)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropWillDelayInterval) && p.WillDelayInterval > 0 {
|
||||
buf.WriteByte(PropWillDelayInterval)
|
||||
buf.Write(encodeUint32(p.WillDelayInterval))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRequestResponseInfo) && p.RequestResponseInfo > 0 {
|
||||
buf.WriteByte(PropRequestResponseInfo)
|
||||
buf.WriteByte(p.RequestResponseInfo)
|
||||
}
|
||||
|
||||
if pk.Mods.AllowResponseInfo && p.canEncode(pkt, PropResponseInfo) && len(p.ResponseInfo) > 0 { // [MQTT-3.1.2-28]
|
||||
buf.WriteByte(PropResponseInfo)
|
||||
buf.Write(encodeString(p.ResponseInfo))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropServerReference) && len(p.ServerReference) > 0 {
|
||||
buf.WriteByte(PropServerReference)
|
||||
buf.Write(encodeString(p.ServerReference))
|
||||
}
|
||||
|
||||
// [MQTT-3.2.2-19] [MQTT-3.14.2-3] [MQTT-3.4.2-2] [MQTT-3.5.2-2]
|
||||
// [MQTT-3.6.2-2] [MQTT-3.9.2-1] [MQTT-3.11.2-1] [MQTT-3.15.2-2]
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropReasonString) && p.ReasonString != "" {
|
||||
b := encodeString(p.ReasonString)
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+len(b)+1) < pk.Mods.MaxSize {
|
||||
buf.WriteByte(PropReasonString)
|
||||
buf.Write(b)
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropReceiveMaximum) && p.ReceiveMaximum > 0 {
|
||||
buf.WriteByte(PropReceiveMaximum)
|
||||
buf.Write(encodeUint16(p.ReceiveMaximum))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropTopicAliasMaximum) && p.TopicAliasMaximum > 0 {
|
||||
buf.WriteByte(PropTopicAliasMaximum)
|
||||
buf.Write(encodeUint16(p.TopicAliasMaximum))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropTopicAlias) && p.TopicAliasFlag && p.TopicAlias > 0 { // [MQTT-3.3.2-8]
|
||||
buf.WriteByte(PropTopicAlias)
|
||||
buf.Write(encodeUint16(p.TopicAlias))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMaximumQos) && p.MaximumQosFlag && p.MaximumQos < 2 {
|
||||
buf.WriteByte(PropMaximumQos)
|
||||
buf.WriteByte(p.MaximumQos)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropRetainAvailable) && p.RetainAvailableFlag {
|
||||
buf.WriteByte(PropRetainAvailable)
|
||||
buf.WriteByte(p.RetainAvailable)
|
||||
}
|
||||
|
||||
if !pk.Mods.DisallowProblemInfo && p.canEncode(pkt, PropUser) {
|
||||
pb := bytes.NewBuffer([]byte{})
|
||||
for _, v := range p.User {
|
||||
pb.WriteByte(PropUser)
|
||||
pb.Write(encodeString(v.Key))
|
||||
pb.Write(encodeString(v.Val))
|
||||
}
|
||||
// [MQTT-3.2.2-20] [MQTT-3.14.2-4] [MQTT-3.4.2-3] [MQTT-3.5.2-3]
|
||||
// [MQTT-3.6.2-3] [MQTT-3.9.2-2] [MQTT-3.11.2-2] [MQTT-3.15.2-3]
|
||||
if pk.Mods.MaxSize == 0 || uint32(n+pb.Len()+1) < pk.Mods.MaxSize {
|
||||
buf.Write(pb.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropMaximumPacketSize) && p.MaximumPacketSize > 0 {
|
||||
buf.WriteByte(PropMaximumPacketSize)
|
||||
buf.Write(encodeUint32(p.MaximumPacketSize))
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropWildcardSubAvailable) && p.WildcardSubAvailableFlag {
|
||||
buf.WriteByte(PropWildcardSubAvailable)
|
||||
buf.WriteByte(p.WildcardSubAvailable)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSubIDAvailable) && p.SubIDAvailableFlag {
|
||||
buf.WriteByte(PropSubIDAvailable)
|
||||
buf.WriteByte(p.SubIDAvailable)
|
||||
}
|
||||
|
||||
if p.canEncode(pkt, PropSharedSubAvailable) && p.SharedSubAvailableFlag {
|
||||
buf.WriteByte(PropSharedSubAvailable)
|
||||
buf.WriteByte(p.SharedSubAvailable)
|
||||
}
|
||||
|
||||
encodeLength(b, int64(buf.Len()))
|
||||
buf.WriteTo(b) // [MQTT-3.1.3-10]
|
||||
}
|
||||
|
||||
// Decode decodes property bytes into a properties struct.
|
||||
func (p *Properties) Decode(pk byte, b *bytes.Buffer) (n int, err error) {
|
||||
if p == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var bu int
|
||||
n, bu, err = DecodeLength(b)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return n + bu, nil
|
||||
}
|
||||
|
||||
bt := b.Bytes()
|
||||
var k byte
|
||||
for offset := 0; offset < n; {
|
||||
k, offset, err = decodeByte(bt, offset)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
|
||||
if _, ok := validPacketProperties[k][pk]; !ok {
|
||||
return n + bu, fmt.Errorf("property type %v not valid for packet type %v: %w", k, pk, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
switch k {
|
||||
case PropPayloadFormat:
|
||||
p.PayloadFormat, offset, err = decodeByte(bt, offset)
|
||||
p.PayloadFormatFlag = true
|
||||
case PropMessageExpiryInterval:
|
||||
p.MessageExpiryInterval, offset, err = decodeUint32(bt, offset)
|
||||
case PropContentType:
|
||||
p.ContentType, offset, err = decodeString(bt, offset)
|
||||
case PropResponseTopic:
|
||||
p.ResponseTopic, offset, err = decodeString(bt, offset)
|
||||
case PropCorrelationData:
|
||||
p.CorrelationData, offset, err = decodeBytes(bt, offset)
|
||||
case PropSubscriptionIdentifier:
|
||||
if p.SubscriptionIdentifier == nil {
|
||||
p.SubscriptionIdentifier = []int{}
|
||||
}
|
||||
|
||||
n, bu, err := DecodeLength(bytes.NewBuffer(bt[offset:]))
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
p.SubscriptionIdentifier = append(p.SubscriptionIdentifier, n)
|
||||
offset += bu
|
||||
case PropSessionExpiryInterval:
|
||||
p.SessionExpiryInterval, offset, err = decodeUint32(bt, offset)
|
||||
p.SessionExpiryIntervalFlag = true
|
||||
case PropAssignedClientID:
|
||||
p.AssignedClientID, offset, err = decodeString(bt, offset)
|
||||
case PropServerKeepAlive:
|
||||
p.ServerKeepAlive, offset, err = decodeUint16(bt, offset)
|
||||
p.ServerKeepAliveFlag = true
|
||||
case PropAuthenticationMethod:
|
||||
p.AuthenticationMethod, offset, err = decodeString(bt, offset)
|
||||
case PropAuthenticationData:
|
||||
p.AuthenticationData, offset, err = decodeBytes(bt, offset)
|
||||
case PropRequestProblemInfo:
|
||||
p.RequestProblemInfo, offset, err = decodeByte(bt, offset)
|
||||
p.RequestProblemInfoFlag = true
|
||||
case PropWillDelayInterval:
|
||||
p.WillDelayInterval, offset, err = decodeUint32(bt, offset)
|
||||
case PropRequestResponseInfo:
|
||||
p.RequestResponseInfo, offset, err = decodeByte(bt, offset)
|
||||
case PropResponseInfo:
|
||||
p.ResponseInfo, offset, err = decodeString(bt, offset)
|
||||
case PropServerReference:
|
||||
p.ServerReference, offset, err = decodeString(bt, offset)
|
||||
case PropReasonString:
|
||||
p.ReasonString, offset, err = decodeString(bt, offset)
|
||||
case PropReceiveMaximum:
|
||||
p.ReceiveMaximum, offset, err = decodeUint16(bt, offset)
|
||||
case PropTopicAliasMaximum:
|
||||
p.TopicAliasMaximum, offset, err = decodeUint16(bt, offset)
|
||||
case PropTopicAlias:
|
||||
p.TopicAlias, offset, err = decodeUint16(bt, offset)
|
||||
p.TopicAliasFlag = true
|
||||
case PropMaximumQos:
|
||||
p.MaximumQos, offset, err = decodeByte(bt, offset)
|
||||
p.MaximumQosFlag = true
|
||||
case PropRetainAvailable:
|
||||
p.RetainAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.RetainAvailableFlag = true
|
||||
case PropUser:
|
||||
var k, v string
|
||||
k, offset, err = decodeString(bt, offset)
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
v, offset, err = decodeString(bt, offset)
|
||||
p.User = append(p.User, UserProperty{Key: k, Val: v})
|
||||
case PropMaximumPacketSize:
|
||||
p.MaximumPacketSize, offset, err = decodeUint32(bt, offset)
|
||||
case PropWildcardSubAvailable:
|
||||
p.WildcardSubAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.WildcardSubAvailableFlag = true
|
||||
case PropSubIDAvailable:
|
||||
p.SubIDAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.SubIDAvailableFlag = true
|
||||
case PropSharedSubAvailable:
|
||||
p.SharedSubAvailable, offset, err = decodeByte(bt, offset)
|
||||
p.SharedSubAvailableFlag = true
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return n + bu, err
|
||||
}
|
||||
}
|
||||
|
||||
return n + bu, nil
|
||||
}
|
||||
333
backend/services/mochi/packets/properties_test.go
Normal file
333
backend/services/mochi/packets/properties_test.go
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
propertiesStruct = Properties{
|
||||
PayloadFormat: byte(1), // UTF-8 Format
|
||||
PayloadFormatFlag: true,
|
||||
MessageExpiryInterval: uint32(2),
|
||||
ContentType: "text/plain",
|
||||
ResponseTopic: "a/b/c",
|
||||
CorrelationData: []byte("data"),
|
||||
SubscriptionIdentifier: []int{322122},
|
||||
SessionExpiryInterval: uint32(120),
|
||||
SessionExpiryIntervalFlag: true,
|
||||
AssignedClientID: "mochi-v5",
|
||||
ServerKeepAlive: uint16(20),
|
||||
ServerKeepAliveFlag: true,
|
||||
AuthenticationMethod: "SHA-1",
|
||||
AuthenticationData: []byte("auth-data"),
|
||||
RequestProblemInfo: byte(1),
|
||||
RequestProblemInfoFlag: true,
|
||||
WillDelayInterval: uint32(600),
|
||||
RequestResponseInfo: byte(1),
|
||||
ResponseInfo: "response",
|
||||
ServerReference: "mochi-2",
|
||||
ReasonString: "reason",
|
||||
ReceiveMaximum: uint16(500),
|
||||
TopicAliasMaximum: uint16(999),
|
||||
TopicAlias: uint16(3),
|
||||
TopicAliasFlag: true,
|
||||
MaximumQos: byte(1),
|
||||
MaximumQosFlag: true,
|
||||
RetainAvailable: byte(1),
|
||||
RetainAvailableFlag: true,
|
||||
User: []UserProperty{
|
||||
{
|
||||
Key: "hello",
|
||||
Val: "世界",
|
||||
},
|
||||
{
|
||||
Key: "key2",
|
||||
Val: "value2",
|
||||
},
|
||||
},
|
||||
MaximumPacketSize: uint32(32000),
|
||||
WildcardSubAvailable: byte(1),
|
||||
WildcardSubAvailableFlag: true,
|
||||
SubIDAvailable: byte(1),
|
||||
SubIDAvailableFlag: true,
|
||||
SharedSubAvailable: byte(1),
|
||||
SharedSubAvailableFlag: true,
|
||||
}
|
||||
|
||||
propertiesBytes = []byte{
|
||||
172, 1, // VBI
|
||||
|
||||
// Payload Format (1) (vbi:2)
|
||||
1, 1,
|
||||
|
||||
// Message Expiry (2) (vbi:7)
|
||||
2, 0, 0, 0, 2,
|
||||
|
||||
// Content Type (3) (vbi:20)
|
||||
3,
|
||||
0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n',
|
||||
|
||||
// Response Topic (8) (vbi:28)
|
||||
8,
|
||||
0, 5, 'a', '/', 'b', '/', 'c',
|
||||
|
||||
// Correlations Data (9) (vbi:35)
|
||||
9,
|
||||
0, 4, 'd', 'a', 't', 'a',
|
||||
|
||||
// Subscription Identifier (11) (vbi:39)
|
||||
11,
|
||||
202, 212, 19,
|
||||
|
||||
// Session Expiry Interval (17) (vbi:43)
|
||||
17,
|
||||
0, 0, 0, 120,
|
||||
|
||||
// Assigned Client ID (18) (vbi:55)
|
||||
18,
|
||||
0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5',
|
||||
|
||||
// Server Keep Alive (19) (vbi:58)
|
||||
19,
|
||||
0, 20,
|
||||
|
||||
// Authentication Method (21) (vbi:66)
|
||||
21,
|
||||
0, 5, 'S', 'H', 'A', '-', '1',
|
||||
|
||||
// Authentication Data (22) (vbi:78)
|
||||
22,
|
||||
0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a',
|
||||
|
||||
// Request Problem Info (23) (vbi:80)
|
||||
23, 1,
|
||||
|
||||
// Will Delay Interval (24) (vbi:85)
|
||||
24,
|
||||
0, 0, 2, 88,
|
||||
|
||||
// Request Response Info (25) (vbi:87)
|
||||
25, 1,
|
||||
|
||||
// Response Info (26) (vbi:98)
|
||||
26,
|
||||
0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e',
|
||||
|
||||
// Server Reference (28) (vbi:108)
|
||||
28,
|
||||
0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2',
|
||||
|
||||
// Reason String (31) (vbi:117)
|
||||
31,
|
||||
0, 6, 'r', 'e', 'a', 's', 'o', 'n',
|
||||
|
||||
// Receive Maximum (33) (vbi:120)
|
||||
33,
|
||||
1, 244,
|
||||
|
||||
// Topic Alias Maximum (34) (vbi:123)
|
||||
34,
|
||||
3, 231,
|
||||
|
||||
// Topic Alias (35) (vbi:126)
|
||||
35,
|
||||
0, 3,
|
||||
|
||||
// Maximum Qos (36) (vbi:128)
|
||||
36, 1,
|
||||
|
||||
// Retain Available (37) (vbi: 130)
|
||||
37, 1,
|
||||
|
||||
// User Properties (38) (vbi:161)
|
||||
38,
|
||||
0, 5, 'h', 'e', 'l', 'l', 'o',
|
||||
0, 6, 228, 184, 150, 231, 149, 140,
|
||||
38,
|
||||
0, 4, 'k', 'e', 'y', '2',
|
||||
0, 6, 'v', 'a', 'l', 'u', 'e', '2',
|
||||
|
||||
// Maximum Packet Size (39) (vbi:166)
|
||||
39,
|
||||
0, 0, 125, 0,
|
||||
|
||||
// Wildcard Subscriptions Available (40) (vbi:168)
|
||||
40, 1,
|
||||
|
||||
// Subscription ID Available (41) (vbi:170)
|
||||
41, 1,
|
||||
|
||||
// Shared Subscriptions Available (42) (vbi:172)
|
||||
42, 1,
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
validPacketProperties[PropPayloadFormat][Reserved] = 1
|
||||
validPacketProperties[PropMessageExpiryInterval][Reserved] = 1
|
||||
validPacketProperties[PropContentType][Reserved] = 1
|
||||
validPacketProperties[PropResponseTopic][Reserved] = 1
|
||||
validPacketProperties[PropCorrelationData][Reserved] = 1
|
||||
validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1
|
||||
validPacketProperties[PropSessionExpiryInterval][Reserved] = 1
|
||||
validPacketProperties[PropAssignedClientID][Reserved] = 1
|
||||
validPacketProperties[PropServerKeepAlive][Reserved] = 1
|
||||
validPacketProperties[PropAuthenticationMethod][Reserved] = 1
|
||||
validPacketProperties[PropAuthenticationData][Reserved] = 1
|
||||
validPacketProperties[PropRequestProblemInfo][Reserved] = 1
|
||||
validPacketProperties[PropWillDelayInterval][Reserved] = 1
|
||||
validPacketProperties[PropRequestResponseInfo][Reserved] = 1
|
||||
validPacketProperties[PropResponseInfo][Reserved] = 1
|
||||
validPacketProperties[PropServerReference][Reserved] = 1
|
||||
validPacketProperties[PropReasonString][Reserved] = 1
|
||||
validPacketProperties[PropReceiveMaximum][Reserved] = 1
|
||||
validPacketProperties[PropTopicAliasMaximum][Reserved] = 1
|
||||
validPacketProperties[PropTopicAlias][Reserved] = 1
|
||||
validPacketProperties[PropMaximumQos][Reserved] = 1
|
||||
validPacketProperties[PropRetainAvailable][Reserved] = 1
|
||||
validPacketProperties[PropUser][Reserved] = 1
|
||||
validPacketProperties[PropMaximumPacketSize][Reserved] = 1
|
||||
validPacketProperties[PropWildcardSubAvailable][Reserved] = 1
|
||||
validPacketProperties[PropSubIDAvailable][Reserved] = 1
|
||||
validPacketProperties[PropSharedSubAvailable][Reserved] = 1
|
||||
}
|
||||
|
||||
func TestEncodeProperties(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
require.Equal(t, propertiesBytes, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowProblemInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{DisallowProblemInfo: true}}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5}))
|
||||
require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8}))
|
||||
}
|
||||
|
||||
func TestEncodePropertiesDisallowResponseInfo(t *testing.T) {
|
||||
props := propertiesStruct
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: false}}, b, 0)
|
||||
require.NotEqual(t, propertiesBytes, b.Bytes())
|
||||
require.NotContains(t, b.Bytes(), []byte{8, 0, 5})
|
||||
require.NotContains(t, b.Bytes(), []byte{9, 0, 4})
|
||||
}
|
||||
|
||||
func TestEncodePropertiesNil(t *testing.T) {
|
||||
type tmp struct {
|
||||
p *Properties
|
||||
}
|
||||
|
||||
pr := tmp{}
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
pr.p.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}}, b, 0)
|
||||
require.Equal(t, []byte{}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestEncodeZeroProperties(t *testing.T) {
|
||||
// [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
|
||||
props := new(Properties)
|
||||
b := bytes.NewBuffer([]byte{})
|
||||
props.Encode(&Packet{FixedHeader: FixedHeader{Type: Reserved}, Mods: Mods{AllowResponseInfo: true}}, b, 0)
|
||||
require.Equal(t, []byte{0x00}, b.Bytes())
|
||||
}
|
||||
|
||||
func TestDecodeProperties(t *testing.T) {
|
||||
b := bytes.NewBuffer(propertiesBytes)
|
||||
|
||||
props := new(Properties)
|
||||
n, err := props.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 172+2, n)
|
||||
require.EqualValues(t, propertiesStruct, *props)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesNil(t *testing.T) {
|
||||
b := bytes.NewBuffer(propertiesBytes)
|
||||
|
||||
type tmp struct {
|
||||
p *Properties
|
||||
}
|
||||
|
||||
pr := tmp{}
|
||||
n, err := pr.p.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadInitialVBI(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, ErrMalformedVariableByteInteger, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesZeroLengthVBI(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{0})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, props, new(Properties))
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadKeyByte(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{64, 1})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesInvalidForPacket(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{1, 99})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesGeneralFailure(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadSubscriptionID(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodePropertiesBadUserProps(t *testing.T) {
|
||||
b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255})
|
||||
props := new(Properties)
|
||||
_, err := props.Decode(Reserved, b)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCopyProperties(t *testing.T) {
|
||||
require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true))
|
||||
}
|
||||
|
||||
func TestCopyPropertiesNoTransfer(t *testing.T) {
|
||||
pkA := propertiesStruct
|
||||
pkB := pkA.Copy(false)
|
||||
|
||||
// Properties which should never be transferred from one connection to another
|
||||
require.Equal(t, uint16(0), pkB.TopicAlias)
|
||||
}
|
||||
3853
backend/services/mochi/packets/tpackets.go
Normal file
3853
backend/services/mochi/packets/tpackets.go
Normal file
File diff suppressed because it is too large
Load Diff
33
backend/services/mochi/packets/tpackets_test.go
Normal file
33
backend/services/mochi/packets/tpackets_test.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package packets
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func encodeTestOK(wanted TPacketCase) bool {
|
||||
if wanted.RawBytes == nil {
|
||||
return false
|
||||
}
|
||||
if wanted.Group != "" && wanted.Group != "encode" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func decodeTestOK(wanted TPacketCase) bool {
|
||||
if wanted.Group != "" && wanted.Group != "decode" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestTPacketCaseGet(t *testing.T) {
|
||||
require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311))
|
||||
require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128)))
|
||||
}
|
||||
1529
backend/services/mochi/server.go
Normal file
1529
backend/services/mochi/server.go
Normal file
File diff suppressed because it is too large
Load Diff
2857
backend/services/mochi/server_test.go
Normal file
2857
backend/services/mochi/server_test.go
Normal file
File diff suppressed because it is too large
Load Diff
61
backend/services/mochi/system/system.go
Normal file
61
backend/services/mochi/system/system.go
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package system
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Info contains atomic counters and values for various server statistics
|
||||
// commonly found in $SYS topics (and others).
|
||||
// based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics
|
||||
type Info struct {
|
||||
Version string `json:"version"` // the current version of the server
|
||||
Started int64 `json:"started"` // the time the server started in unix seconds
|
||||
Time int64 `json:"time"` // current time on the server
|
||||
Uptime int64 `json:"uptime"` // the number of seconds the server has been online
|
||||
BytesReceived int64 `json:"bytes_received"` // total number of bytes received since the broker started
|
||||
BytesSent int64 `json:"bytes_sent"` // total number of bytes sent since the broker started
|
||||
ClientsConnected int64 `json:"clients_connected"` // number of currently connected clients
|
||||
ClientsDisconnected int64 `json:"clients_disconnected"` // total number of persistent clients (with clean session disabled) that are registered at the broker but are currently disconnected
|
||||
ClientsMaximum int64 `json:"clients_maximum"` // maximum number of active clients that have been connected
|
||||
ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered
|
||||
MessagesReceived int64 `json:"messages_received"` // total number of publish messages received
|
||||
MessagesSent int64 `json:"messages_sent"` // total number of publish messages sent
|
||||
MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber
|
||||
Retained int64 `json:"retained"` // total number of retained messages active on the broker
|
||||
Inflight int64 `json:"inflight"` // the number of messages currently in-flight
|
||||
InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped
|
||||
Subscriptions int64 `json:"subscriptions"` // total number of subscriptions active on the broker
|
||||
PacketsReceived int64 `json:"packets_received"` // the total number of publish messages received
|
||||
PacketsSent int64 `json:"packets_sent"` // total number of messages of any type sent since the broker started
|
||||
MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated
|
||||
Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity
|
||||
}
|
||||
|
||||
// Clone makes a copy of Info using atomic operation
|
||||
func (i *Info) Clone() *Info {
|
||||
return &Info{
|
||||
Version: i.Version,
|
||||
Started: atomic.LoadInt64(&i.Started),
|
||||
Time: atomic.LoadInt64(&i.Time),
|
||||
Uptime: atomic.LoadInt64(&i.Uptime),
|
||||
BytesReceived: atomic.LoadInt64(&i.BytesReceived),
|
||||
BytesSent: atomic.LoadInt64(&i.BytesSent),
|
||||
ClientsConnected: atomic.LoadInt64(&i.ClientsConnected),
|
||||
ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum),
|
||||
ClientsTotal: atomic.LoadInt64(&i.ClientsTotal),
|
||||
ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected),
|
||||
MessagesReceived: atomic.LoadInt64(&i.MessagesReceived),
|
||||
MessagesSent: atomic.LoadInt64(&i.MessagesSent),
|
||||
MessagesDropped: atomic.LoadInt64(&i.MessagesDropped),
|
||||
Retained: atomic.LoadInt64(&i.Retained),
|
||||
Inflight: atomic.LoadInt64(&i.Inflight),
|
||||
InflightDropped: atomic.LoadInt64(&i.InflightDropped),
|
||||
Subscriptions: atomic.LoadInt64(&i.Subscriptions),
|
||||
PacketsReceived: atomic.LoadInt64(&i.PacketsReceived),
|
||||
PacketsSent: atomic.LoadInt64(&i.PacketsSent),
|
||||
MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc),
|
||||
Threads: atomic.LoadInt64(&i.Threads),
|
||||
}
|
||||
}
|
||||
37
backend/services/mochi/system/system_test.go
Normal file
37
backend/services/mochi/system/system_test.go
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
package system
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
o := &Info{
|
||||
Version: "version",
|
||||
Started: 1,
|
||||
Time: 2,
|
||||
Uptime: 3,
|
||||
BytesReceived: 4,
|
||||
BytesSent: 5,
|
||||
ClientsConnected: 6,
|
||||
ClientsMaximum: 7,
|
||||
ClientsTotal: 8,
|
||||
ClientsDisconnected: 9,
|
||||
MessagesReceived: 10,
|
||||
MessagesSent: 11,
|
||||
MessagesDropped: 20,
|
||||
Retained: 12,
|
||||
Inflight: 13,
|
||||
InflightDropped: 14,
|
||||
Subscriptions: 15,
|
||||
PacketsReceived: 16,
|
||||
PacketsSent: 17,
|
||||
MemoryAlloc: 18,
|
||||
Threads: 19,
|
||||
}
|
||||
|
||||
n := o.Clone()
|
||||
|
||||
require.Equal(t, o, n)
|
||||
}
|
||||
699
backend/services/mochi/topics.go
Normal file
699
backend/services/mochi/topics.go
Normal file
|
|
@ -0,0 +1,699 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
)
|
||||
|
||||
var (
|
||||
SharePrefix = "$SHARE" // the prefix indicating a share topic
|
||||
SysPrefix = "$SYS" // the prefix indicating a system info topic
|
||||
)
|
||||
|
||||
// TopicAliases contains inbound and outbound topic alias registrations.
|
||||
type TopicAliases struct {
|
||||
Inbound *InboundTopicAliases
|
||||
Outbound *OutboundTopicAliases
|
||||
}
|
||||
|
||||
// NewTopicAliases returns an instance of TopicAliases.
|
||||
func NewTopicAliases(topicAliasMaximum uint16) TopicAliases {
|
||||
return TopicAliases{
|
||||
Inbound: NewInboundTopicAliases(topicAliasMaximum),
|
||||
Outbound: NewOutboundTopicAliases(topicAliasMaximum),
|
||||
}
|
||||
}
|
||||
|
||||
// NewInboundTopicAliases returns a pointer to InboundTopicAliases.
|
||||
func NewInboundTopicAliases(topicAliasMaximum uint16) *InboundTopicAliases {
|
||||
return &InboundTopicAliases{
|
||||
maximum: topicAliasMaximum,
|
||||
internal: map[uint16]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// InboundTopicAliases contains a map of topic aliases received from the client.
|
||||
type InboundTopicAliases struct {
|
||||
internal map[uint16]string
|
||||
sync.RWMutex
|
||||
maximum uint16
|
||||
}
|
||||
|
||||
// Set sets a new alias for a specific topic.
|
||||
func (a *InboundTopicAliases) Set(id uint16, topic string) string {
|
||||
a.Lock()
|
||||
defer a.Unlock()
|
||||
|
||||
if a.maximum == 0 {
|
||||
return topic // ?
|
||||
}
|
||||
|
||||
if existing, ok := a.internal[id]; ok && topic == "" {
|
||||
return existing
|
||||
}
|
||||
|
||||
a.internal[id] = topic
|
||||
return topic
|
||||
}
|
||||
|
||||
// OutboundTopicAliases contains a map of topic aliases sent from the broker to the client.
|
||||
type OutboundTopicAliases struct {
|
||||
internal map[string]uint16
|
||||
sync.RWMutex
|
||||
cursor uint32
|
||||
maximum uint16
|
||||
}
|
||||
|
||||
// NewOutboundTopicAliases returns a pointer to OutboundTopicAliases.
|
||||
func NewOutboundTopicAliases(topicAliasMaximum uint16) *OutboundTopicAliases {
|
||||
return &OutboundTopicAliases{
|
||||
maximum: topicAliasMaximum,
|
||||
internal: map[string]uint16{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a new topic alias for a topic and returns the alias value, and a boolean
|
||||
// indicating if the alias already existed.
|
||||
func (a *OutboundTopicAliases) Set(topic string) (uint16, bool) {
|
||||
a.Lock()
|
||||
defer a.Unlock()
|
||||
|
||||
if a.maximum == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if i, ok := a.internal[topic]; ok {
|
||||
return i, true
|
||||
}
|
||||
|
||||
i := atomic.LoadUint32(&a.cursor)
|
||||
if i+1 > uint32(a.maximum) {
|
||||
// if i+1 > math.MaxUint16 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
a.internal[topic] = uint16(i) + 1
|
||||
atomic.StoreUint32(&a.cursor, i+1)
|
||||
return uint16(i) + 1, false
|
||||
}
|
||||
|
||||
// SharedSubscriptions contains a map of subscriptions to a shared filter,
|
||||
// keyed on share group then client id.
|
||||
type SharedSubscriptions struct {
|
||||
internal map[string]map[string]packets.Subscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSharedSubscriptions returns a new instance of Subscriptions.
|
||||
func NewSharedSubscriptions() *SharedSubscriptions {
|
||||
return &SharedSubscriptions{
|
||||
internal: map[string]map[string]packets.Subscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add creates a new shared subscription for a group and client id pair.
|
||||
func (s *SharedSubscriptions) Add(group, id string, val packets.Subscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
if _, ok := s.internal[group]; !ok {
|
||||
s.internal[group] = map[string]packets.Subscription{}
|
||||
}
|
||||
s.internal[group][id] = val
|
||||
}
|
||||
|
||||
// Delete deletes a client id from a shared subscription group.
|
||||
func (s *SharedSubscriptions) Delete(group, id string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal[group], id)
|
||||
if len(s.internal[group]) == 0 {
|
||||
delete(s.internal, group)
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the subscription properties for a client id in a share group, if one exists.
|
||||
func (s *SharedSubscriptions) Get(group, id string) (val packets.Subscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
if _, ok := s.internal[group]; !ok {
|
||||
return val, ok
|
||||
}
|
||||
|
||||
val, ok = s.internal[group][id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// GroupLen returns the number of groups subscribed to the filter.
|
||||
func (s *SharedSubscriptions) GroupLen() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Len returns the total number of shared subscriptions to a filter across all groups.
|
||||
func (s *SharedSubscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
n := 0
|
||||
for _, group := range s.internal {
|
||||
n += len(group)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// GetAll returns all shared subscription groups and their subscriptions.
|
||||
func (s *SharedSubscriptions) GetAll() map[string]map[string]packets.Subscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[string]map[string]packets.Subscription{}
|
||||
for group, subs := range s.internal {
|
||||
if _, ok := m[group]; !ok {
|
||||
m[group] = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for id, sub := range subs {
|
||||
m[group][id] = sub
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Subscriptions is a map of subscriptions keyed on client.
|
||||
type Subscriptions struct {
|
||||
internal map[string]packets.Subscription
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSubscriptions returns a new instance of Subscriptions.
|
||||
func NewSubscriptions() *Subscriptions {
|
||||
return &Subscriptions{
|
||||
internal: map[string]packets.Subscription{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new subscription for a client. ID can be a filter in the
|
||||
// case this map is client state, or a client id if particle state.
|
||||
func (s *Subscriptions) Add(id string, val packets.Subscription) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.internal[id] = val
|
||||
}
|
||||
|
||||
// GetAll returns all subscriptions.
|
||||
func (s *Subscriptions) GetAll() map[string]packets.Subscription {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
m := map[string]packets.Subscription{}
|
||||
for k, v := range s.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Get returns a subscriptions for a specific client or filter id.
|
||||
func (s *Subscriptions) Get(id string) (val packets.Subscription, ok bool) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val, ok = s.internal[id]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Len returns the number of subscriptions.
|
||||
func (s *Subscriptions) Len() int {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
val := len(s.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// Delete removes a subscription by client or filter id.
|
||||
func (s *Subscriptions) Delete(id string) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
delete(s.internal, id)
|
||||
}
|
||||
|
||||
// ClientSubscriptions is a map of aggregated subscriptions for a client.
|
||||
type ClientSubscriptions map[string]packets.Subscription
|
||||
|
||||
// Subscribers contains the shared and non-shared subscribers matching a topic.
|
||||
type Subscribers struct {
|
||||
Shared map[string]map[string]packets.Subscription
|
||||
SharedSelected map[string]packets.Subscription
|
||||
Subscriptions map[string]packets.Subscription
|
||||
}
|
||||
|
||||
// SelectShared returns one subscriber for each shared subscription group.
|
||||
func (s *Subscribers) SelectShared() {
|
||||
s.SharedSelected = map[string]packets.Subscription{}
|
||||
for _, subs := range s.Shared {
|
||||
for client, sub := range subs {
|
||||
cls, ok := s.SharedSelected[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
s.SharedSelected[client] = cls.Merge(sub)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MergeSharedSelected merges the selected subscribers for a shared subscription group
|
||||
// and the non-shared subscribers, to ensure that no subscriber gets multiple messages
|
||||
// due to have both types of subscription matching the same filter.
|
||||
func (s *Subscribers) MergeSharedSelected() {
|
||||
for client, sub := range s.SharedSelected {
|
||||
cls, ok := s.Subscriptions[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
s.Subscriptions[client] = cls.Merge(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// TopicsIndex is a prefix/trie tree containing topic subscribers and retained messages.
|
||||
type TopicsIndex struct {
|
||||
Retained *packets.Packets
|
||||
root *particle // a leaf containing a message and more leaves.
|
||||
}
|
||||
|
||||
// NewTopicsIndex returns a pointer to a new instance of Index.
|
||||
func NewTopicsIndex() *TopicsIndex {
|
||||
return &TopicsIndex{
|
||||
Retained: packets.NewPackets(),
|
||||
root: &particle{
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new subscription for a client to a topic filter, returning
|
||||
// true if the subscription was new.
|
||||
func (x *TopicsIndex) Subscribe(client string, subscription packets.Subscription) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var existed bool
|
||||
prefix, _ := isolateParticle(subscription.Filter, 0)
|
||||
if strings.EqualFold(prefix, SharePrefix) {
|
||||
group, _ := isolateParticle(subscription.Filter, 1)
|
||||
n := x.set(subscription.Filter, 2)
|
||||
_, existed = n.shared.Get(group, client)
|
||||
n.shared.Add(group, client, subscription)
|
||||
} else {
|
||||
n := x.set(subscription.Filter, 0)
|
||||
_, existed = n.subscriptions.Get(client)
|
||||
n.subscriptions.Add(client, subscription)
|
||||
}
|
||||
|
||||
return !existed
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscription filter for a client, returning true if the
|
||||
// subscription existed.
|
||||
func (x *TopicsIndex) Unsubscribe(filter, client string) bool {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
var d int
|
||||
if strings.HasPrefix(filter, SharePrefix) {
|
||||
d = 2
|
||||
}
|
||||
|
||||
particle := x.seek(filter, d)
|
||||
if particle == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
prefix, _ := isolateParticle(filter, 0)
|
||||
if strings.EqualFold(prefix, SharePrefix) {
|
||||
group, _ := isolateParticle(filter, 1)
|
||||
particle.shared.Delete(group, client)
|
||||
} else {
|
||||
particle.subscriptions.Delete(client)
|
||||
}
|
||||
|
||||
x.trim(particle)
|
||||
return true
|
||||
}
|
||||
|
||||
// RetainMessage saves a message payload to the end of a topic address. Returns
|
||||
// 1 if a retained message was added, and -1 if the retained message was removed.
|
||||
// 0 is returned if sequential empty payloads are received.
|
||||
func (x *TopicsIndex) RetainMessage(pk packets.Packet) int64 {
|
||||
x.root.Lock()
|
||||
defer x.root.Unlock()
|
||||
|
||||
n := x.set(pk.TopicName, 0)
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
if len(pk.Payload) > 0 {
|
||||
n.retainPath = pk.TopicName
|
||||
x.Retained.Add(pk.TopicName, pk)
|
||||
return 1
|
||||
}
|
||||
|
||||
var out int64
|
||||
if pke, ok := x.Retained.Get(pk.TopicName); ok && len(pke.Payload) > 0 && pke.FixedHeader.Retain {
|
||||
out = -1 // if a retained packet existed, return -1
|
||||
}
|
||||
|
||||
n.retainPath = ""
|
||||
x.Retained.Delete(pk.TopicName) // [MQTT-3.3.1-6] [MQTT-3.3.1-7]
|
||||
x.trim(n)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// set creates a topic address in the index and returns the final particle.
|
||||
func (x *TopicsIndex) set(topic string, d int) *particle {
|
||||
var key string
|
||||
var hasNext = true
|
||||
n := x.root
|
||||
for hasNext {
|
||||
key, hasNext = isolateParticle(topic, d)
|
||||
d++
|
||||
|
||||
p := n.particles.get(key)
|
||||
if p == nil {
|
||||
p = newParticle(key, n)
|
||||
n.particles.add(p)
|
||||
}
|
||||
n = p
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// seek finds the particle at a specific index in a topic filter.
|
||||
func (x *TopicsIndex) seek(filter string, d int) *particle {
|
||||
var key string
|
||||
var hasNext = true
|
||||
n := x.root
|
||||
for hasNext {
|
||||
key, hasNext = isolateParticle(filter, d)
|
||||
n = n.particles.get(key)
|
||||
d++
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// trim removes empty filter particles from the index.
|
||||
func (x *TopicsIndex) trim(n *particle) {
|
||||
for n.parent != nil && n.retainPath == "" && n.particles.len()+n.subscriptions.Len()+n.shared.Len() == 0 {
|
||||
key := n.key
|
||||
n = n.parent
|
||||
n.particles.delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Messages returns a slice of any retained messages which match a filter.
|
||||
func (x *TopicsIndex) Messages(filter string) []packets.Packet {
|
||||
return x.scanMessages(filter, 0, nil, []packets.Packet{})
|
||||
}
|
||||
|
||||
// scanMessages returns all retained messages on topics matching a given filter.
|
||||
func (x *TopicsIndex) scanMessages(filter string, d int, n *particle, pks []packets.Packet) []packets.Packet {
|
||||
if n == nil {
|
||||
n = x.root
|
||||
}
|
||||
|
||||
if len(filter) == 0 || x.Retained.Len() == 0 {
|
||||
return pks
|
||||
}
|
||||
|
||||
if !strings.ContainsRune(filter, '#') && !strings.ContainsRune(filter, '+') {
|
||||
if pk, ok := x.Retained.Get(filter); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
return pks
|
||||
}
|
||||
|
||||
key, hasNext := isolateParticle(filter, d)
|
||||
if key == "+" || key == "#" || d == -1 {
|
||||
for _, adjacent := range n.particles.getAll() {
|
||||
if d == 0 && adjacent.key == SysPrefix {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hasNext {
|
||||
if adjacent.retainPath != "" {
|
||||
if pk, ok := x.Retained.Get(adjacent.retainPath); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasNext || (d >= 0 && key == "#") {
|
||||
pks = x.scanMessages(filter, d+1, adjacent, pks)
|
||||
}
|
||||
}
|
||||
return pks
|
||||
}
|
||||
|
||||
if particle := n.particles.get(key); particle != nil {
|
||||
if hasNext {
|
||||
return x.scanMessages(filter, d+1, particle, pks)
|
||||
}
|
||||
|
||||
if pk, ok := x.Retained.Get(particle.retainPath); ok {
|
||||
pks = append(pks, pk)
|
||||
}
|
||||
}
|
||||
|
||||
return pks
|
||||
}
|
||||
|
||||
// Subscribers returns a map of clients who are subscribed to matching filters,
|
||||
// their subscription ids and highest qos.
|
||||
func (x *TopicsIndex) Subscribers(topic string) *Subscribers {
|
||||
return x.scanSubscribers(topic, 0, nil, &Subscribers{
|
||||
Shared: map[string]map[string]packets.Subscription{},
|
||||
SharedSelected: map[string]packets.Subscription{},
|
||||
Subscriptions: map[string]packets.Subscription{},
|
||||
})
|
||||
}
|
||||
|
||||
// scanSubscribers returns a list of client subscriptions matching an indexed topic address.
|
||||
func (x *TopicsIndex) scanSubscribers(topic string, d int, n *particle, subs *Subscribers) *Subscribers {
|
||||
if n == nil {
|
||||
n = x.root
|
||||
}
|
||||
|
||||
if len(topic) == 0 {
|
||||
return subs
|
||||
}
|
||||
|
||||
key, hasNext := isolateParticle(topic, d)
|
||||
for _, partKey := range []string{key, "+", "#"} {
|
||||
if particle := n.particles.get(partKey); particle != nil { // [MQTT-3.3.2-3]
|
||||
x.gatherSubscriptions(topic, particle, subs)
|
||||
x.gatherSharedSubscriptions(particle, subs)
|
||||
if wild := particle.particles.get("#"); wild != nil && partKey != "#" && partKey != "+" {
|
||||
x.gatherSubscriptions(topic, wild, subs) // also match any subs where filter/# is filter as per 4.7.1.2
|
||||
}
|
||||
|
||||
if hasNext {
|
||||
x.scanSubscribers(topic, d+1, particle, subs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return subs
|
||||
}
|
||||
|
||||
// gatherSubscriptions collects any matching subscriptions, and gathers any identifiers or highest qos values.
|
||||
func (x *TopicsIndex) gatherSubscriptions(topic string, particle *particle, subs *Subscribers) {
|
||||
if subs.Subscriptions == nil {
|
||||
subs.Subscriptions = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for client, sub := range particle.subscriptions.GetAll() {
|
||||
if len(sub.Filter) > 0 && topic[0] == '$' && (sub.Filter[0] == '+' || sub.Filter[0] == '#') { // don't match $ topics with top level wildcards [MQTT-4.7.1-1] [MQTT-4.7.1-2]
|
||||
continue
|
||||
}
|
||||
|
||||
cls, ok := subs.Subscriptions[client]
|
||||
if !ok {
|
||||
cls = sub
|
||||
}
|
||||
|
||||
subs.Subscriptions[client] = cls.Merge(sub)
|
||||
}
|
||||
}
|
||||
|
||||
// gatherSharedSubscriptions gathers all shared subscriptions for a particle.
|
||||
func (x *TopicsIndex) gatherSharedSubscriptions(particle *particle, subs *Subscribers) {
|
||||
if subs.Shared == nil {
|
||||
subs.Shared = map[string]map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
for _, shares := range particle.shared.GetAll() {
|
||||
for client, sub := range shares {
|
||||
if _, ok := subs.Shared[sub.Filter]; !ok {
|
||||
subs.Shared[sub.Filter] = map[string]packets.Subscription{}
|
||||
}
|
||||
|
||||
subs.Shared[sub.Filter][client] = sub
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isolateParticle extracts a particle between d / and d+1 / without allocations.
|
||||
func isolateParticle(filter string, d int) (particle string, hasNext bool) {
|
||||
var next, end int
|
||||
for i := 0; end > -1 && i <= d; i++ {
|
||||
end = strings.IndexRune(filter, '/')
|
||||
|
||||
switch {
|
||||
case d > -1 && i == d && end > -1:
|
||||
hasNext = true
|
||||
particle = filter[next:end]
|
||||
case end > -1:
|
||||
hasNext = false
|
||||
filter = filter[end+1:]
|
||||
default:
|
||||
hasNext = false
|
||||
particle = filter[next:]
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsSharedFilter returns true if the filter uses the share prefix.
|
||||
func IsSharedFilter(filter string) bool {
|
||||
prefix, _ := isolateParticle(filter, 0)
|
||||
return strings.EqualFold(prefix, SharePrefix)
|
||||
}
|
||||
|
||||
// IsValidFilter returns true if the filter is valid.
|
||||
func IsValidFilter(filter string, forPublish bool) bool {
|
||||
if !forPublish && len(filter) == 0 { // publishing can accept zero-length topic filter if topic alias exists, so we don't enforce for publihs.
|
||||
return false // [MQTT-4.7.3-1]
|
||||
}
|
||||
|
||||
if forPublish {
|
||||
if len(filter) >= len(SysPrefix) && strings.EqualFold(filter[0:len(SysPrefix)], SysPrefix) {
|
||||
// 4.7.2 Non-normative - The Server SHOULD prevent Clients from using such Topic Names [$SYS] to exchange messages with other Clients.
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.ContainsRune(filter, '+') || strings.ContainsRune(filter, '#') {
|
||||
return false //[MQTT-3.3.2-2]
|
||||
}
|
||||
}
|
||||
|
||||
wildhash := strings.IndexRune(filter, '#')
|
||||
if wildhash >= 0 && wildhash != len(filter)-1 { // [MQTT-4.7.1-2]
|
||||
return false
|
||||
}
|
||||
|
||||
prefix, hasNext := isolateParticle(filter, 0)
|
||||
if !hasNext && strings.EqualFold(prefix, SharePrefix) {
|
||||
return false // [MQTT-4.8.2-1]
|
||||
}
|
||||
|
||||
if hasNext && strings.EqualFold(prefix, SharePrefix) {
|
||||
group, hasNext := isolateParticle(filter, 1)
|
||||
if !hasNext {
|
||||
return false // [MQTT-4.8.2-1]
|
||||
}
|
||||
|
||||
if strings.ContainsRune(group, '+') || strings.ContainsRune(group, '#') {
|
||||
return false // [MQTT-4.8.2-2]
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// particle is a child node on the tree.
|
||||
type particle struct {
|
||||
key string // the key of the particle
|
||||
parent *particle // a pointer to the parent of the particle
|
||||
particles particles // a map of child particles
|
||||
subscriptions *Subscriptions // a map of subscriptions made by clients to this ending address
|
||||
shared *SharedSubscriptions // a map of shared subscriptions keyed on group name
|
||||
retainPath string // path of a retained message
|
||||
sync.Mutex // mutex for when making changes to the particle
|
||||
}
|
||||
|
||||
// newParticle returns a pointer to a new instance of particle.
|
||||
func newParticle(key string, parent *particle) *particle {
|
||||
return &particle{
|
||||
key: key,
|
||||
parent: parent,
|
||||
particles: newParticles(),
|
||||
subscriptions: NewSubscriptions(),
|
||||
shared: NewSharedSubscriptions(),
|
||||
}
|
||||
}
|
||||
|
||||
// particles is a concurrency safe map of particles.
|
||||
type particles struct {
|
||||
internal map[string]*particle
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// newParticles returns a map of particles.
|
||||
func newParticles() particles {
|
||||
return particles{
|
||||
internal: map[string]*particle{},
|
||||
}
|
||||
}
|
||||
|
||||
// add adds a new particle.
|
||||
func (p *particles) add(val *particle) {
|
||||
p.Lock()
|
||||
p.internal[val.key] = val
|
||||
p.Unlock()
|
||||
}
|
||||
|
||||
// getAll returns all particles.
|
||||
func (p *particles) getAll() map[string]*particle {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
m := map[string]*particle{}
|
||||
for k, v := range p.internal {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// get returns a particle by id (key).
|
||||
func (p *particles) get(id string) *particle {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
return p.internal[id]
|
||||
}
|
||||
|
||||
// len returns the number of particles.
|
||||
func (p *particles) len() int {
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
val := len(p.internal)
|
||||
return val
|
||||
}
|
||||
|
||||
// delete removes a particle.
|
||||
func (p *particles) delete(id string) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
delete(p.internal, id)
|
||||
}
|
||||
842
backend/services/mochi/topics_test.go
Normal file
842
backend/services/mochi/topics_test.go
Normal file
|
|
@ -0,0 +1,842 @@
|
|||
// SPDX-License-Identifier: MIT
|
||||
// SPDX-FileCopyrightText: 2022 J. Blake / mochi-co
|
||||
// SPDX-FileContributor: mochi-co
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mochi-co/mqtt/v2/packets"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testGroup = "testgroup"
|
||||
otherGroup = "other"
|
||||
)
|
||||
|
||||
func TestNewSharedSubscriptions(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
require.NotNil(t, s.internal)
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsAdd(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Filter: "a/b/c"})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsGet(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Qos: 2})
|
||||
s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl2")
|
||||
|
||||
sub, ok := s.Get(testGroup, "cl2")
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, byte(2), sub.Qos)
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsGetAll(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
|
||||
s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
|
||||
s.Add(otherGroup, "cl3", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl2")
|
||||
require.Contains(t, s.internal, otherGroup)
|
||||
require.Contains(t, s.internal[otherGroup], "cl3")
|
||||
|
||||
subs := s.GetAll()
|
||||
require.Len(t, subs, 2)
|
||||
require.Len(t, subs[testGroup], 2)
|
||||
require.Len(t, subs[otherGroup], 1)
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsLen(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Qos: 0})
|
||||
s.Add(testGroup, "cl2", packets.Subscription{Qos: 1})
|
||||
s.Add(otherGroup, "cl2", packets.Subscription{Qos: 1})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl2")
|
||||
require.Contains(t, s.internal, otherGroup)
|
||||
require.Contains(t, s.internal[otherGroup], "cl2")
|
||||
require.Equal(t, 3, s.Len())
|
||||
require.Equal(t, 2, s.GroupLen())
|
||||
}
|
||||
|
||||
func TestSharedSubscriptionsDelete(t *testing.T) {
|
||||
s := NewSharedSubscriptions()
|
||||
s.Add(testGroup, "cl1", packets.Subscription{Qos: 1})
|
||||
s.Add(testGroup, "cl2", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl1")
|
||||
require.Contains(t, s.internal, testGroup)
|
||||
require.Contains(t, s.internal[testGroup], "cl2")
|
||||
|
||||
require.Equal(t, 2, s.Len())
|
||||
|
||||
s.Delete(testGroup, "cl1")
|
||||
_, ok := s.Get(testGroup, "cl1")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 1, s.GroupLen())
|
||||
require.Equal(t, 1, s.Len())
|
||||
|
||||
s.Delete(testGroup, "cl2")
|
||||
_, ok = s.Get(testGroup, "cl2")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, 0, s.GroupLen())
|
||||
require.Equal(t, 0, s.Len())
|
||||
}
|
||||
|
||||
func TestNewSubscriptions(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
require.NotNil(t, s.internal)
|
||||
}
|
||||
|
||||
func TestSubscriptionsAdd(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
}
|
||||
|
||||
func TestSubscriptionsGet(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{Qos: 2})
|
||||
s.Add("cl2", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
|
||||
sub, ok := s.Get("cl1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, byte(2), sub.Qos)
|
||||
}
|
||||
|
||||
func TestSubscriptionsGetAll(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{Qos: 0})
|
||||
s.Add("cl2", packets.Subscription{Qos: 1})
|
||||
s.Add("cl3", packets.Subscription{Qos: 2})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Contains(t, s.internal, "cl3")
|
||||
|
||||
subs := s.GetAll()
|
||||
require.Len(t, subs, 3)
|
||||
}
|
||||
|
||||
func TestSubscriptionsLen(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{Qos: 0})
|
||||
s.Add("cl2", packets.Subscription{Qos: 1})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
require.Contains(t, s.internal, "cl2")
|
||||
require.Equal(t, 2, s.Len())
|
||||
}
|
||||
|
||||
func TestSubscriptionsDelete(t *testing.T) {
|
||||
s := NewSubscriptions()
|
||||
s.Add("cl1", packets.Subscription{Qos: 1})
|
||||
require.Contains(t, s.internal, "cl1")
|
||||
|
||||
s.Delete("cl1")
|
||||
_, ok := s.Get("cl1")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestNewTopicsIndex(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
require.NotNil(t, index)
|
||||
require.NotNil(t, index.root)
|
||||
}
|
||||
|
||||
func BenchmarkNewTopicsIndex(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
NewTopicsIndex()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
tt := []struct {
|
||||
desc string
|
||||
client string
|
||||
filter string
|
||||
subscription packets.Subscription
|
||||
wasNew bool
|
||||
}{
|
||||
{
|
||||
desc: "subscribe",
|
||||
client: "cl1",
|
||||
|
||||
subscription: packets.Subscription{Filter: "a/b/c", Qos: 2},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "subscribe existed",
|
||||
client: "cl1",
|
||||
|
||||
subscription: packets.Subscription{Filter: "a/b/c", Qos: 1},
|
||||
wasNew: false,
|
||||
},
|
||||
{
|
||||
desc: "subscribe case sensitive didnt exist",
|
||||
client: "cl1",
|
||||
|
||||
subscription: packets.Subscription{Filter: "A/B/c", Qos: 1},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard+ sub",
|
||||
client: "cl1",
|
||||
|
||||
subscription: packets.Subscription{Filter: "d/+"},
|
||||
wasNew: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard# sub",
|
||||
client: "cl1",
|
||||
subscription: packets.Subscription{Filter: "d/e/#"},
|
||||
wasNew: true,
|
||||
},
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
for _, tx := range tt {
|
||||
t.Run(tx.desc, func(t *testing.T) {
|
||||
require.Equal(t, tx.wasNew, index.Subscribe(tx.client, tx.subscription))
|
||||
})
|
||||
}
|
||||
|
||||
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||
require.NotNil(t, final)
|
||||
client, exists := final.subscriptions.Get("cl1")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, byte(1), client.Qos)
|
||||
}
|
||||
|
||||
func TestSubscribeShared(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: SharePrefix + "/tmp/a/b/c", Qos: 2})
|
||||
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||
require.NotNil(t, final)
|
||||
client, exists := final.shared.Get("tmp", "cl1")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, byte(2), client.Qos)
|
||||
require.Equal(t, 0, final.subscriptions.Len())
|
||||
require.Equal(t, 1, final.shared.Len())
|
||||
}
|
||||
|
||||
func BenchmarkSubscribe(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribe("client-1", packets.Subscription{Filter: "a/b/c"})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSubscribeShared(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribe("client-1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c"})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/d", Qos: 1})
|
||||
client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/+/d", Qos: 1})
|
||||
client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "d/e/f", Qos: 1})
|
||||
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
index.Subscribe("cl2", packets.Subscription{Filter: "d/e/f", Qos: 1})
|
||||
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
index.Subscribe("cl3", packets.Subscription{Filter: "#", Qos: 2})
|
||||
client, exists = index.root.particles.get("#").subscriptions.Get("cl3")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
ok := index.Unsubscribe("a/b/c/d", "cl1")
|
||||
require.True(t, ok)
|
||||
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
client, exists = index.root.particles.get("a").particles.get("b").particles.get("+").particles.get("d").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
ok = index.Unsubscribe("d/e/f", "cl1")
|
||||
require.True(t, ok)
|
||||
|
||||
require.Equal(t, 1, index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Len())
|
||||
client, exists = index.root.particles.get("d").particles.get("e").particles.get("f").subscriptions.Get("cl2")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
|
||||
ok = index.Unsubscribe("fdasfdas/dfsfads/sa", "nobody")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestUnsubscribeNoCascade(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/e/e"})
|
||||
|
||||
ok := index.Unsubscribe("a/b/c/e/e", "cl1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 1, index.root.particles.len())
|
||||
|
||||
client, exists := index.root.particles.get("a").particles.get("b").particles.get("c").subscriptions.Get("cl1")
|
||||
require.NotNil(t, client)
|
||||
require.True(t, exists)
|
||||
}
|
||||
|
||||
func TestUnsubscribeShared(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "$SHARE/tmp/a/b/c", Qos: 2})
|
||||
final := index.root.particles.get("a").particles.get("b").particles.get("c")
|
||||
require.NotNil(t, final)
|
||||
client, exists := final.shared.Get("tmp", "cl1")
|
||||
require.True(t, exists)
|
||||
require.Equal(t, byte(2), client.Qos)
|
||||
|
||||
require.True(t, index.Unsubscribe("$SHARE/tmp/a/b/c", "cl1"))
|
||||
_, exists = final.shared.Get("tmp", "cl1")
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func BenchmarkUnsubscribe(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
b.StopTimer()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
|
||||
b.StartTimer()
|
||||
index.Unsubscribe("a/b/c", "cl1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexSeek(t *testing.T) {
|
||||
filter := "a/b/c/d/e/f"
|
||||
index := NewTopicsIndex()
|
||||
k1 := index.set(filter, 0)
|
||||
require.Equal(t, "f", k1.key)
|
||||
k1.subscriptions.Add("cl1", packets.Subscription{})
|
||||
|
||||
require.Equal(t, k1, index.seek(filter, 0))
|
||||
require.Nil(t, index.seek("d/e/f", 0))
|
||||
}
|
||||
|
||||
func TestIndexTrim(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
k1 := index.set("a/b/c", 0)
|
||||
require.Equal(t, "c", k1.key)
|
||||
k1.subscriptions.Add("cl1", packets.Subscription{})
|
||||
|
||||
k2 := index.set("a/b/c/d/e/f", 0)
|
||||
require.Equal(t, "f", k2.key)
|
||||
k2.subscriptions.Add("cl1", packets.Subscription{})
|
||||
|
||||
k3 := index.set("a/b", 0)
|
||||
require.Equal(t, "b", k3.key)
|
||||
k3.subscriptions.Add("cl1", packets.Subscription{})
|
||||
|
||||
index.trim(k2)
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d").particles.get("e").particles.get("f"))
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b"))
|
||||
|
||||
k2.subscriptions.Delete("cl1")
|
||||
index.trim(k2)
|
||||
|
||||
require.Nil(t, index.root.particles.get("a").particles.get("b").particles.get("c").particles.get("d"))
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
|
||||
k1.subscriptions.Delete("cl1")
|
||||
k3.subscriptions.Delete("cl1")
|
||||
index.trim(k2)
|
||||
require.Nil(t, index.root.particles.get("a"))
|
||||
}
|
||||
|
||||
func TestIndexSet(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
child := index.set("a/b/c", 0)
|
||||
require.Equal(t, "c", child.key)
|
||||
require.NotNil(t, index.root.particles.get("a").particles.get("b").particles.get("c"))
|
||||
|
||||
child = index.set("a/b/c/d/e", 0)
|
||||
require.Equal(t, "e", child.key)
|
||||
|
||||
child = index.set("a/b/c/c/a", 0)
|
||||
require.Equal(t, "a", child.key)
|
||||
}
|
||||
|
||||
func TestIndexSetPrefixed(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
child := index.set("/c", 0)
|
||||
require.Equal(t, "c", child.key)
|
||||
require.NotNil(t, index.root.particles.get("").particles.get("c"))
|
||||
}
|
||||
|
||||
func BenchmarkIndexSet(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.set("a/b/c", 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetainMessage(t *testing.T) {
|
||||
pk := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{Retain: true},
|
||||
TopicName: "a/b/c",
|
||||
Payload: []byte("hello"),
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
r := index.RetainMessage(pk)
|
||||
require.Equal(t, int64(1), r)
|
||||
pke, ok := index.Retained.Get(pk.TopicName)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, pk, pke)
|
||||
|
||||
pk2 := packets.Packet{
|
||||
FixedHeader: packets.FixedHeader{Retain: true},
|
||||
TopicName: "a/b/d/f",
|
||||
Payload: []byte("hello"),
|
||||
}
|
||||
r = index.RetainMessage(pk2)
|
||||
require.Equal(t, int64(1), r)
|
||||
// The same message already exists, but we're not doing a deep-copy check, so it's considered to be a new message.
|
||||
r = index.RetainMessage(pk2)
|
||||
require.Equal(t, int64(1), r)
|
||||
|
||||
// Clear existing retained
|
||||
pk3 := packets.Packet{TopicName: "a/b/c", Payload: []byte{}}
|
||||
r = index.RetainMessage(pk3)
|
||||
require.Equal(t, int64(-1), r)
|
||||
_, ok = index.Retained.Get(pk.TopicName)
|
||||
require.False(t, ok)
|
||||
|
||||
// Clear no retained
|
||||
r = index.RetainMessage(pk3)
|
||||
require.Equal(t, int64(0), r)
|
||||
}
|
||||
|
||||
func BenchmarkRetainMessage(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsolateParticle(t *testing.T) {
|
||||
particle, hasNext := isolateParticle("path/to/my/mqtt", 0)
|
||||
require.Equal(t, "path", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 1)
|
||||
require.Equal(t, "to", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 2)
|
||||
require.Equal(t, "my", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("path/to/my/mqtt", 3)
|
||||
require.Equal(t, "mqtt", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
|
||||
particle, hasNext = isolateParticle("/path/", 0)
|
||||
require.Equal(t, "", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("/path/", 1)
|
||||
require.Equal(t, "path", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("/path/", 2)
|
||||
require.Equal(t, "", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
|
||||
particle, hasNext = isolateParticle("a/b/c/+/+", 3)
|
||||
require.Equal(t, "+", particle)
|
||||
require.Equal(t, true, hasNext)
|
||||
particle, hasNext = isolateParticle("a/b/c/+/+", 4)
|
||||
require.Equal(t, "+", particle)
|
||||
require.Equal(t, false, hasNext)
|
||||
}
|
||||
|
||||
func BenchmarkIsolateParticle(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
isolateParticle("path/to/my/mqtt", 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSubscribers(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c", Identifier: 22})
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: "a/b/c/d/e/f"})
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 2, Filter: "a/b/c/d/+/f"})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "a/#"})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 1, Filter: "a/b/c"})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "a/b/+", Identifier: 77})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "d/e/f", Identifier: 7237})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 2, Filter: "$SYS/uptime", Identifier: 3})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: "+/b", Identifier: 234})
|
||||
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: "#", Identifier: 5})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: "$SYS/test", Identifier: 2})
|
||||
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 4, len(subs.Subscriptions))
|
||||
require.Contains(t, subs.Subscriptions, "cl1")
|
||||
require.Contains(t, subs.Subscriptions, "cl2")
|
||||
require.Contains(t, subs.Subscriptions, "cl3")
|
||||
require.Contains(t, subs.Subscriptions, "cl4")
|
||||
|
||||
require.Equal(t, byte(1), subs.Subscriptions["cl1"].Qos)
|
||||
require.Equal(t, byte(2), subs.Subscriptions["cl2"].Qos)
|
||||
require.Equal(t, byte(1), subs.Subscriptions["cl3"].Qos)
|
||||
require.Equal(t, byte(0), subs.Subscriptions["cl4"].Qos)
|
||||
|
||||
require.Equal(t, 22, subs.Subscriptions["cl1"].Identifiers["a/b/c"])
|
||||
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/#"])
|
||||
require.Equal(t, 77, subs.Subscriptions["cl2"].Identifiers["a/b/+"])
|
||||
require.Equal(t, 0, subs.Subscriptions["cl2"].Identifiers["a/b/c"])
|
||||
require.Equal(t, 234, subs.Subscriptions["cl3"].Identifiers["+/b"])
|
||||
require.Equal(t, 5, subs.Subscriptions["cl4"].Identifiers["#"])
|
||||
|
||||
subs = index.scanSubscribers("", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 0, len(subs.Subscriptions))
|
||||
}
|
||||
|
||||
func TestScanSubscribersShared(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 10})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 200})
|
||||
index.Subscribe("cl4", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/+", Identifier: 201})
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 3, len(subs.Shared))
|
||||
}
|
||||
|
||||
func TestSelectSharedSubscriber(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110})
|
||||
index.Subscribe("cl1b", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 111})
|
||||
index.Subscribe("cl2", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 112})
|
||||
index.Subscribe("cl3", packets.Subscription{Qos: 0, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 113})
|
||||
subs := index.scanSubscribers("a/b/c", 0, nil, new(Subscribers))
|
||||
require.Equal(t, 2, len(subs.Shared))
|
||||
require.Contains(t, subs.Shared, SharePrefix+"/tmp/a/b/c")
|
||||
require.Contains(t, subs.Shared, SharePrefix+"/tmp2/a/b/c")
|
||||
require.Len(t, subs.Shared[SharePrefix+"/tmp/a/b/c"], 3)
|
||||
require.Len(t, subs.Shared[SharePrefix+"/tmp2/a/b/c"], 1)
|
||||
subs.SelectShared()
|
||||
require.Len(t, subs.SharedSelected, 2)
|
||||
}
|
||||
|
||||
func TestMergeSharedSelected(t *testing.T) {
|
||||
s := &Subscribers{
|
||||
SharedSelected: map[string]packets.Subscription{
|
||||
"cl1": {Qos: 1, Filter: SharePrefix + "/tmp/a/b/c", Identifier: 110},
|
||||
"cl2": {Qos: 1, Filter: SharePrefix + "/tmp2/a/b/c", Identifier: 111},
|
||||
},
|
||||
Subscriptions: map[string]packets.Subscription{
|
||||
"cl2": {Qos: 1, Filter: "a/b/c", Identifier: 112},
|
||||
},
|
||||
}
|
||||
|
||||
s.MergeSharedSelected()
|
||||
|
||||
require.Equal(t, 2, len(s.Subscriptions))
|
||||
require.Contains(t, s.Subscriptions, "cl1")
|
||||
require.Contains(t, s.Subscriptions, "cl2")
|
||||
require.EqualValues(t, map[string]int{
|
||||
SharePrefix + "/tmp2/a/b/c": 111,
|
||||
"a/b/c": 112,
|
||||
}, s.Subscriptions["cl2"].Identifiers)
|
||||
}
|
||||
|
||||
func TestSubscribersFind(t *testing.T) {
|
||||
tt := []struct {
|
||||
filter string
|
||||
topic string
|
||||
matched bool
|
||||
}{
|
||||
{filter: "a", topic: "a", matched: true},
|
||||
{filter: "a/", topic: "a", matched: false},
|
||||
{filter: "a/", topic: "a/", matched: true},
|
||||
{filter: "/a", topic: "/a", matched: true},
|
||||
{filter: "path/to/my/mqtt", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "path/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "+/to/+/mqtt", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "#", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "+/+/+/+", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "+/+/+/#", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "zen/#", topic: "zen", matched: true}, // as per 4.7.1.2
|
||||
{filter: "trailing-end/#", topic: "trailing-end/", matched: true},
|
||||
{filter: "+/prefixed", topic: "/prefixed", matched: true},
|
||||
{filter: "+/+/#", topic: "path/to/my/mqtt", matched: true},
|
||||
{filter: "path/to/", topic: "path/to/my/mqtt", matched: false},
|
||||
{filter: "#/stuff", topic: "path/to/my/mqtt", matched: false},
|
||||
{filter: "#", topic: "$SYS/info", matched: false},
|
||||
{filter: "$SYS/#", topic: "$SYS/info", matched: true},
|
||||
{filter: "+/info", topic: "$SYS/info", matched: false},
|
||||
}
|
||||
|
||||
for _, tx := range tt {
|
||||
t.Run("filter:'"+tx.filter+"' vs topic:'"+tx.topic+"'", func(t *testing.T) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: tx.filter})
|
||||
subs := index.Subscribers(tx.topic)
|
||||
require.Equal(t, tx.matched, len(subs.Subscriptions) == 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSubscribers(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c"})
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/+/c"})
|
||||
index.Subscribe("cl1", packets.Subscription{Filter: "a/b/c/+"})
|
||||
index.Subscribe("cl2", packets.Subscription{Filter: "a/b/c/d"})
|
||||
index.Subscribe("cl3", packets.Subscription{Filter: "#"})
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Subscribers("a/b/c")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPattern(t *testing.T) {
|
||||
payload := []byte("hello")
|
||||
fh := packets.FixedHeader{Type: packets.Publish, Retain: true}
|
||||
|
||||
pks := []packets.Packet{
|
||||
{TopicName: "$SYS/uptime", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "$SYS/info", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "a/b/c/d", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "a/b/c/e", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "a/b/d/f", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "q/w/e/r/t/y", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "q/x/e/r/t/o", Payload: payload, FixedHeader: fh},
|
||||
{TopicName: "asdf", Payload: payload, FixedHeader: fh},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
filter string
|
||||
len int
|
||||
}{
|
||||
{"a/b/c/d", 1},
|
||||
{"$SYS/+", 2},
|
||||
{"$SYS/#", 2},
|
||||
{"#", len(pks) - 2},
|
||||
{"a/b/c/+", 2},
|
||||
{"a/+/c/+", 2},
|
||||
{"+/+/+/d", 1},
|
||||
{"q/w/e/#", 1},
|
||||
{"+/+/+/+", 3},
|
||||
{"q/#", 2},
|
||||
{"asdf", 1},
|
||||
{"", 0},
|
||||
{"#", 6},
|
||||
}
|
||||
|
||||
index := NewTopicsIndex()
|
||||
for _, pk := range pks {
|
||||
index.RetainMessage(pk)
|
||||
}
|
||||
|
||||
for _, tx := range tt {
|
||||
t.Run("filter:'"+tx.filter, func(t *testing.T) {
|
||||
messages := index.Messages(tx.filter)
|
||||
require.Equal(t, tx.len, len(messages))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessages(b *testing.B) {
|
||||
index := NewTopicsIndex()
|
||||
index.RetainMessage(packets.Packet{TopicName: "a/b/c/d"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "a/b/d/e/f"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "d/e/f/g"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "$SYS/info"})
|
||||
index.RetainMessage(packets.Packet{TopicName: "q/w/e/r/t/y"})
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
index.Messages("+/b/c/+")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewParticles(t *testing.T) {
|
||||
cl := newParticles()
|
||||
require.NotNil(t, cl.internal)
|
||||
}
|
||||
|
||||
func TestParticlesAdd(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
}
|
||||
|
||||
func TestParticlesGet(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
p.add(&particle{key: "b"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
require.Contains(t, p.internal, "b")
|
||||
|
||||
particle := p.get("a")
|
||||
require.NotNil(t, particle)
|
||||
require.Equal(t, "a", particle.key)
|
||||
}
|
||||
|
||||
func TestParticlesGetAll(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
p.add(&particle{key: "b"})
|
||||
p.add(&particle{key: "c"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
require.Contains(t, p.internal, "b")
|
||||
require.Contains(t, p.internal, "c")
|
||||
|
||||
particles := p.getAll()
|
||||
require.Len(t, particles, 3)
|
||||
}
|
||||
|
||||
func TestParticlesLen(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
p.add(&particle{key: "b"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
require.Contains(t, p.internal, "b")
|
||||
require.Equal(t, 2, p.len())
|
||||
}
|
||||
|
||||
func TestParticlesDelete(t *testing.T) {
|
||||
p := newParticles()
|
||||
p.add(&particle{key: "a"})
|
||||
require.Contains(t, p.internal, "a")
|
||||
|
||||
p.delete("a")
|
||||
particle := p.get("a")
|
||||
require.Nil(t, particle)
|
||||
}
|
||||
|
||||
func TestIsValid(t *testing.T) {
|
||||
require.True(t, IsValidFilter("a/b/c", false))
|
||||
require.True(t, IsValidFilter("a/b//c", false))
|
||||
require.True(t, IsValidFilter("$SYS", false))
|
||||
require.True(t, IsValidFilter("$SYS/info", false))
|
||||
require.True(t, IsValidFilter("$sys/info", false))
|
||||
require.True(t, IsValidFilter("abc/#", false))
|
||||
require.False(t, IsValidFilter("", false))
|
||||
require.False(t, IsValidFilter(SharePrefix, false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/", false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/b+/", false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/+", false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/#", false))
|
||||
require.False(t, IsValidFilter(SharePrefix+"/#/", false))
|
||||
require.False(t, IsValidFilter("a/#/c", false))
|
||||
}
|
||||
|
||||
func TestIsValidForPublish(t *testing.T) {
|
||||
require.True(t, IsValidFilter("", true))
|
||||
require.True(t, IsValidFilter("a/b/c", true))
|
||||
require.False(t, IsValidFilter("a/b/+/d", true))
|
||||
require.False(t, IsValidFilter("a/b/#", true))
|
||||
require.False(t, IsValidFilter("$SYS/info", true))
|
||||
}
|
||||
|
||||
func TestIsSharedFilter(t *testing.T) {
|
||||
require.True(t, IsSharedFilter(SharePrefix+"/tmp/a/b/c"))
|
||||
require.False(t, IsSharedFilter("a/b/c"))
|
||||
}
|
||||
|
||||
func TestNewInboundAliases(t *testing.T) {
|
||||
a := NewInboundTopicAliases(5)
|
||||
require.NotNil(t, a)
|
||||
require.NotNil(t, a.internal)
|
||||
require.Equal(t, uint16(5), a.maximum)
|
||||
}
|
||||
|
||||
func TestInboundAliasesSet(t *testing.T) {
|
||||
topic := "test"
|
||||
id := uint16(1)
|
||||
a := NewInboundTopicAliases(5)
|
||||
require.Equal(t, topic, a.Set(id, topic))
|
||||
require.Contains(t, a.internal, id)
|
||||
require.Equal(t, a.internal[id], topic)
|
||||
|
||||
require.Equal(t, topic, a.Set(id, ""))
|
||||
}
|
||||
|
||||
func TestInboundAliasesSetMaxZero(t *testing.T) {
|
||||
topic := "test"
|
||||
id := uint16(1)
|
||||
a := NewInboundTopicAliases(0)
|
||||
require.Equal(t, topic, a.Set(id, topic))
|
||||
require.NotContains(t, a.internal, id)
|
||||
}
|
||||
|
||||
func TestNewOutboundAliases(t *testing.T) {
|
||||
a := NewOutboundTopicAliases(5)
|
||||
require.NotNil(t, a)
|
||||
require.NotNil(t, a.internal)
|
||||
require.Equal(t, uint16(5), a.maximum)
|
||||
require.Equal(t, uint32(0), a.cursor)
|
||||
}
|
||||
|
||||
func TestOutboundAliasesSet(t *testing.T) {
|
||||
a := NewOutboundTopicAliases(3)
|
||||
n, ok := a.Set("t1")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(1), n)
|
||||
|
||||
n, ok = a.Set("t2")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(2), n)
|
||||
|
||||
n, ok = a.Set("t3")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(3), n)
|
||||
|
||||
n, ok = a.Set("t4")
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(0), n)
|
||||
|
||||
n, ok = a.Set("t2")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, uint16(2), n)
|
||||
}
|
||||
|
||||
func TestOutboundAliasesSetMaxZero(t *testing.T) {
|
||||
topic := "test"
|
||||
a := NewOutboundTopicAliases(0)
|
||||
n, ok := a.Set(topic)
|
||||
require.False(t, ok)
|
||||
require.Equal(t, uint16(0), n)
|
||||
}
|
||||
|
||||
func TestNewTopicAliases(t *testing.T) {
|
||||
a := NewTopicAliases(5)
|
||||
require.NotNil(t, a.Inbound)
|
||||
require.Equal(t, uint16(5), a.Inbound.maximum)
|
||||
require.NotNil(t, a.Outbound)
|
||||
require.Equal(t, uint16(5), a.Outbound.maximum)
|
||||
}
|
||||
1
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/.travis.yml
generated
vendored
Normal file
1
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/.travis.yml
generated
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
language: go
|
||||
35
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/LICENSE
generated
vendored
Normal file
35
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/LICENSE
generated
vendored
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
bbloom.go
|
||||
|
||||
// The MIT License (MIT)
|
||||
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
|
||||
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
siphash.go
|
||||
|
||||
// https://github.com/dchest/siphash
|
||||
//
|
||||
// Written in 2012 by Dmitry Chestnykh.
|
||||
//
|
||||
// To the extent possible under law, the author have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
// http://creativecommons.org/publicdomain/zero/1.0/
|
||||
//
|
||||
// Package siphash implements SipHash-2-4, a fast short-input PRF
|
||||
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
|
||||
131
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/README.md
generated
vendored
Normal file
131
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/README.md
generated
vendored
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
## bbloom: a bitset Bloom filter for go/golang
|
||||
===
|
||||
|
||||
[](http://travis-ci.org/AndreasBriese/bbloom)
|
||||
|
||||
package implements a fast bloom filter with real 'bitset' and JSONMarshal/JSONUnmarshal to store/reload the Bloom filter.
|
||||
|
||||
NOTE: the package uses unsafe.Pointer to set and read the bits from the bitset. If you're uncomfortable with using the unsafe package, please consider using my bloom filter package at github.com/AndreasBriese/bloom
|
||||
|
||||
===
|
||||
|
||||
changelog 11/2015: new thread safe methods AddTS(), HasTS(), AddIfNotHasTS() following a suggestion from Srdjan Marinovic (github @a-little-srdjan), who used this to code a bloomfilter cache.
|
||||
|
||||
This bloom filter was developed to strengthen a website-log database and was tested and optimized for this log-entry mask: "2014/%02i/%02i %02i:%02i:%02i /info.html".
|
||||
Nonetheless bbloom should work with any other form of entries.
|
||||
|
||||
~~Hash function is a modified Berkeley DB sdbm hash (to optimize for smaller strings). sdbm http://www.cse.yorku.ca/~oz/hash.html~~
|
||||
|
||||
Found sipHash (SipHash-2-4, a fast short-input PRF created by Jean-Philippe Aumasson and Daniel J. Bernstein.) to be about as fast. sipHash had been ported by Dimtry Chestnyk to Go (github.com/dchest/siphash )
|
||||
|
||||
Minimum hashset size is: 512 ([4]uint64; will be set automatically).
|
||||
|
||||
###install
|
||||
|
||||
```sh
|
||||
go get github.com/AndreasBriese/bbloom
|
||||
```
|
||||
|
||||
###test
|
||||
+ change to folder ../bbloom
|
||||
+ create wordlist in file "words.txt" (you might use `python permut.py`)
|
||||
+ run 'go test -bench=.' within the folder
|
||||
|
||||
```go
|
||||
go test -bench=.
|
||||
```
|
||||
|
||||
~~If you've installed the GOCONVEY TDD-framework http://goconvey.co/ you can run the tests automatically.~~
|
||||
|
||||
using go's testing framework now (have in mind that the op timing is related to 65536 operations of Add, Has, AddIfNotHas respectively)
|
||||
|
||||
### usage
|
||||
|
||||
after installation add
|
||||
|
||||
```go
|
||||
import (
|
||||
...
|
||||
"github.com/AndreasBriese/bbloom"
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
at your header. In the program use
|
||||
|
||||
```go
|
||||
// create a bloom filter for 65536 items and 1 % wrong-positive ratio
|
||||
bf := bbloom.New(float64(1<<16), float64(0.01))
|
||||
|
||||
// or
|
||||
// create a bloom filter with 650000 for 65536 items and 7 locs per hash explicitly
|
||||
// bf = bbloom.New(float64(650000), float64(7))
|
||||
// or
|
||||
bf = bbloom.New(650000.0, 7.0)
|
||||
|
||||
// add one item
|
||||
bf.Add([]byte("butter"))
|
||||
|
||||
// Number of elements added is exposed now
|
||||
// Note: ElemNum will not be included in JSON export (for compatability to older version)
|
||||
nOfElementsInFilter := bf.ElemNum
|
||||
|
||||
// check if item is in the filter
|
||||
isIn := bf.Has([]byte("butter")) // should be true
|
||||
isNotIn := bf.Has([]byte("Butter")) // should be false
|
||||
|
||||
// 'add only if item is new' to the bloomfilter
|
||||
added := bf.AddIfNotHas([]byte("butter")) // should be false because 'butter' is already in the set
|
||||
added = bf.AddIfNotHas([]byte("buTTer")) // should be true because 'buTTer' is new
|
||||
|
||||
// thread safe versions for concurrent use: AddTS, HasTS, AddIfNotHasTS
|
||||
// add one item
|
||||
bf.AddTS([]byte("peanutbutter"))
|
||||
// check if item is in the filter
|
||||
isIn = bf.HasTS([]byte("peanutbutter")) // should be true
|
||||
isNotIn = bf.HasTS([]byte("peanutButter")) // should be false
|
||||
// 'add only if item is new' to the bloomfilter
|
||||
added = bf.AddIfNotHasTS([]byte("butter")) // should be false because 'peanutbutter' is already in the set
|
||||
added = bf.AddIfNotHasTS([]byte("peanutbuTTer")) // should be true because 'penutbuTTer' is new
|
||||
|
||||
// convert to JSON ([]byte)
|
||||
Json := bf.JSONMarshal()
|
||||
|
||||
// bloomfilters Mutex is exposed for external un-/locking
|
||||
// i.e. mutex lock while doing JSON conversion
|
||||
bf.Mtx.Lock()
|
||||
Json = bf.JSONMarshal()
|
||||
bf.Mtx.Unlock()
|
||||
|
||||
// restore a bloom filter from storage
|
||||
bfNew := bbloom.JSONUnmarshal(Json)
|
||||
|
||||
isInNew := bfNew.Has([]byte("butter")) // should be true
|
||||
isNotInNew := bfNew.Has([]byte("Butter")) // should be false
|
||||
|
||||
```
|
||||
|
||||
to work with the bloom filter.
|
||||
|
||||
### why 'fast'?
|
||||
|
||||
It's about 3 times faster than William Fitzgeralds bitset bloom filter https://github.com/willf/bloom . And it is about so fast as my []bool set variant for Boom filters (see https://github.com/AndreasBriese/bloom ) but having a 8times smaller memory footprint:
|
||||
|
||||
|
||||
Bloom filter (filter size 524288, 7 hashlocs)
|
||||
github.com/AndreasBriese/bbloom 'Add' 65536 items (10 repetitions): 6595800 ns (100 ns/op)
|
||||
github.com/AndreasBriese/bbloom 'Has' 65536 items (10 repetitions): 5986600 ns (91 ns/op)
|
||||
github.com/AndreasBriese/bloom 'Add' 65536 items (10 repetitions): 6304684 ns (96 ns/op)
|
||||
github.com/AndreasBriese/bloom 'Has' 65536 items (10 repetitions): 6568663 ns (100 ns/op)
|
||||
|
||||
github.com/willf/bloom 'Add' 65536 items (10 repetitions): 24367224 ns (371 ns/op)
|
||||
github.com/willf/bloom 'Test' 65536 items (10 repetitions): 21881142 ns (333 ns/op)
|
||||
github.com/dataence/bloom/standard 'Add' 65536 items (10 repetitions): 23041644 ns (351 ns/op)
|
||||
github.com/dataence/bloom/standard 'Check' 65536 items (10 repetitions): 19153133 ns (292 ns/op)
|
||||
github.com/cabello/bloom 'Add' 65536 items (10 repetitions): 131921507 ns (2012 ns/op)
|
||||
github.com/cabello/bloom 'Contains' 65536 items (10 repetitions): 131108962 ns (2000 ns/op)
|
||||
|
||||
(on MBPro15 OSX10.8.5 i7 4Core 2.4Ghz)
|
||||
|
||||
|
||||
With 32bit bloom filters (bloom32) using modified sdbm, bloom32 does hashing with only 2 bit shifts, one xor and one substraction per byte. smdb is about as fast as fnv64a but gives less collisions with the dataset (see mask above). bloom.New(float64(10 * 1<<16),float64(7)) populated with 1<<16 random items from the dataset (see above) and tested against the rest results in less than 0.05% collisions.
|
||||
284
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/bbloom.go
generated
vendored
Normal file
284
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/bbloom.go
generated
vendored
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
// The MIT License (MIT)
|
||||
// Copyright (c) 2014 Andreas Briese, eduToolbox@Bri-C GmbH, Sarstedt
|
||||
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
// this software and associated documentation files (the "Software"), to deal in
|
||||
// the Software without restriction, including without limitation the rights to
|
||||
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
// the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
// subject to the following conditions:
|
||||
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
// 2019/08/25 code revision to reduce unsafe use
|
||||
// Parts are adopted from the fork at ipfs/bbloom after performance rev by
|
||||
// Steve Allen (https://github.com/Stebalien)
|
||||
// (see https://github.com/ipfs/bbloom/blob/master/bbloom.go)
|
||||
// -> func Has
|
||||
// -> func set
|
||||
// -> func add
|
||||
|
||||
package bbloom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// helper
|
||||
// not needed anymore by Set
|
||||
// var mask = []uint8{1, 2, 4, 8, 16, 32, 64, 128}
|
||||
|
||||
func getSize(ui64 uint64) (size uint64, exponent uint64) {
|
||||
if ui64 < uint64(512) {
|
||||
ui64 = uint64(512)
|
||||
}
|
||||
size = uint64(1)
|
||||
for size < ui64 {
|
||||
size <<= 1
|
||||
exponent++
|
||||
}
|
||||
return size, exponent
|
||||
}
|
||||
|
||||
func calcSizeByWrongPositives(numEntries, wrongs float64) (uint64, uint64) {
|
||||
size := -1 * numEntries * math.Log(wrongs) / math.Pow(float64(0.69314718056), 2)
|
||||
locs := math.Ceil(float64(0.69314718056) * size / numEntries)
|
||||
return uint64(size), uint64(locs)
|
||||
}
|
||||
|
||||
// New
|
||||
// returns a new bloomfilter
|
||||
func New(params ...float64) (bloomfilter Bloom) {
|
||||
var entries, locs uint64
|
||||
if len(params) == 2 {
|
||||
if params[1] < 1 {
|
||||
entries, locs = calcSizeByWrongPositives(params[0], params[1])
|
||||
} else {
|
||||
entries, locs = uint64(params[0]), uint64(params[1])
|
||||
}
|
||||
} else {
|
||||
log.Fatal("usage: New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(3)) or New(float64(number_of_entries), float64(number_of_hashlocations)) i.e. New(float64(1000), float64(0.03))")
|
||||
}
|
||||
size, exponent := getSize(uint64(entries))
|
||||
bloomfilter = Bloom{
|
||||
Mtx: &sync.Mutex{},
|
||||
sizeExp: exponent,
|
||||
size: size - 1,
|
||||
setLocs: locs,
|
||||
shift: 64 - exponent,
|
||||
}
|
||||
bloomfilter.Size(size)
|
||||
return bloomfilter
|
||||
}
|
||||
|
||||
// NewWithBoolset
|
||||
// takes a []byte slice and number of locs per entry
|
||||
// returns the bloomfilter with a bitset populated according to the input []byte
|
||||
func NewWithBoolset(bs *[]byte, locs uint64) (bloomfilter Bloom) {
|
||||
bloomfilter = New(float64(len(*bs)<<3), float64(locs))
|
||||
for i, b := range *bs {
|
||||
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bloomfilter.bitset[0])) + uintptr(i))) = b
|
||||
}
|
||||
return bloomfilter
|
||||
}
|
||||
|
||||
// bloomJSONImExport
|
||||
// Im/Export structure used by JSONMarshal / JSONUnmarshal
|
||||
type bloomJSONImExport struct {
|
||||
FilterSet []byte
|
||||
SetLocs uint64
|
||||
}
|
||||
|
||||
// JSONUnmarshal
|
||||
// takes JSON-Object (type bloomJSONImExport) as []bytes
|
||||
// returns Bloom object
|
||||
func JSONUnmarshal(dbData []byte) Bloom {
|
||||
bloomImEx := bloomJSONImExport{}
|
||||
json.Unmarshal(dbData, &bloomImEx)
|
||||
buf := bytes.NewBuffer(bloomImEx.FilterSet)
|
||||
bs := buf.Bytes()
|
||||
bf := NewWithBoolset(&bs, bloomImEx.SetLocs)
|
||||
return bf
|
||||
}
|
||||
|
||||
//
|
||||
// Bloom filter
|
||||
type Bloom struct {
|
||||
Mtx *sync.Mutex
|
||||
ElemNum uint64
|
||||
bitset []uint64
|
||||
sizeExp uint64
|
||||
size uint64
|
||||
setLocs uint64
|
||||
shift uint64
|
||||
}
|
||||
|
||||
// <--- http://www.cse.yorku.ca/~oz/hash.html
|
||||
// modified Berkeley DB Hash (32bit)
|
||||
// hash is casted to l, h = 16bit fragments
|
||||
// func (bl Bloom) absdbm(b *[]byte) (l, h uint64) {
|
||||
// hash := uint64(len(*b))
|
||||
// for _, c := range *b {
|
||||
// hash = uint64(c) + (hash << 6) + (hash << bl.sizeExp) - hash
|
||||
// }
|
||||
// h = hash >> bl.shift
|
||||
// l = hash << bl.shift >> bl.shift
|
||||
// return l, h
|
||||
// }
|
||||
|
||||
// Update: found sipHash of Jean-Philippe Aumasson & Daniel J. Bernstein to be even faster than absdbm()
|
||||
// https://131002.net/siphash/
|
||||
// siphash was implemented for Go by Dmitry Chestnykh https://github.com/dchest/siphash
|
||||
|
||||
// Add
|
||||
// set the bit(s) for entry; Adds an entry to the Bloom filter
|
||||
func (bl *Bloom) Add(entry []byte) {
|
||||
l, h := bl.sipHash(entry)
|
||||
for i := uint64(0); i < bl.setLocs; i++ {
|
||||
bl.set((h + i*l) & bl.size)
|
||||
bl.ElemNum++
|
||||
}
|
||||
}
|
||||
|
||||
// AddTS
|
||||
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
|
||||
func (bl *Bloom) AddTS(entry []byte) {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
bl.Add(entry)
|
||||
}
|
||||
|
||||
// Has
|
||||
// check if bit(s) for entry is/are set
|
||||
// returns true if the entry was added to the Bloom Filter
|
||||
func (bl Bloom) Has(entry []byte) bool {
|
||||
l, h := bl.sipHash(entry)
|
||||
res := true
|
||||
for i := uint64(0); i < bl.setLocs; i++ {
|
||||
res = res && bl.isSet((h+i*l)&bl.size)
|
||||
// https://github.com/ipfs/bbloom/commit/84e8303a9bfb37b2658b85982921d15bbb0fecff
|
||||
// // Branching here (early escape) is not worth it
|
||||
// // This is my conclusion from benchmarks
|
||||
// // (prevents loop unrolling)
|
||||
// switch bl.IsSet((h + i*l) & bl.size) {
|
||||
// case false:
|
||||
// return false
|
||||
// }
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// HasTS
|
||||
// Thread safe: Mutex.Lock the bloomfilter for the time of processing the entry
|
||||
func (bl *Bloom) HasTS(entry []byte) bool {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
return bl.Has(entry)
|
||||
}
|
||||
|
||||
// AddIfNotHas
|
||||
// Only Add entry if it's not present in the bloomfilter
|
||||
// returns true if entry was added
|
||||
// returns false if entry was allready registered in the bloomfilter
|
||||
func (bl Bloom) AddIfNotHas(entry []byte) (added bool) {
|
||||
if bl.Has(entry) {
|
||||
return added
|
||||
}
|
||||
bl.Add(entry)
|
||||
return true
|
||||
}
|
||||
|
||||
// AddIfNotHasTS
|
||||
// Tread safe: Only Add entry if it's not present in the bloomfilter
|
||||
// returns true if entry was added
|
||||
// returns false if entry was allready registered in the bloomfilter
|
||||
func (bl *Bloom) AddIfNotHasTS(entry []byte) (added bool) {
|
||||
bl.Mtx.Lock()
|
||||
defer bl.Mtx.Unlock()
|
||||
return bl.AddIfNotHas(entry)
|
||||
}
|
||||
|
||||
// Size
|
||||
// make Bloom filter with as bitset of size sz
|
||||
func (bl *Bloom) Size(sz uint64) {
|
||||
bl.bitset = make([]uint64, sz>>6)
|
||||
}
|
||||
|
||||
// Clear
|
||||
// resets the Bloom filter
|
||||
func (bl *Bloom) Clear() {
|
||||
bs := bl.bitset
|
||||
for i := range bs {
|
||||
bs[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Set
|
||||
// set the bit[idx] of bitsit
|
||||
func (bl *Bloom) set(idx uint64) {
|
||||
// ommit unsafe
|
||||
// *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3))) |= mask[idx%8]
|
||||
bl.bitset[idx>>6] |= 1 << (idx % 64)
|
||||
}
|
||||
|
||||
// IsSet
|
||||
// check if bit[idx] of bitset is set
|
||||
// returns true/false
|
||||
func (bl *Bloom) isSet(idx uint64) bool {
|
||||
// ommit unsafe
|
||||
// return (((*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[idx>>6])) + uintptr((idx%64)>>3)))) >> (idx % 8)) & 1) == 1
|
||||
return bl.bitset[idx>>6]&(1<<(idx%64)) != 0
|
||||
}
|
||||
|
||||
// JSONMarshal
|
||||
// returns JSON-object (type bloomJSONImExport) as []byte
|
||||
func (bl Bloom) JSONMarshal() []byte {
|
||||
bloomImEx := bloomJSONImExport{}
|
||||
bloomImEx.SetLocs = uint64(bl.setLocs)
|
||||
bloomImEx.FilterSet = make([]byte, len(bl.bitset)<<3)
|
||||
for i := range bloomImEx.FilterSet {
|
||||
bloomImEx.FilterSet[i] = *(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&bl.bitset[0])) + uintptr(i)))
|
||||
}
|
||||
data, err := json.Marshal(bloomImEx)
|
||||
if err != nil {
|
||||
log.Fatal("json.Marshal failed: ", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// // alternative hashFn
|
||||
// func (bl Bloom) fnv64a(b *[]byte) (l, h uint64) {
|
||||
// h64 := fnv.New64a()
|
||||
// h64.Write(*b)
|
||||
// hash := h64.Sum64()
|
||||
// h = hash >> 32
|
||||
// l = hash << 32 >> 32
|
||||
// return l, h
|
||||
// }
|
||||
//
|
||||
// // <-- http://partow.net/programming/hashfunctions/index.html
|
||||
// // citation: An algorithm proposed by Donald E. Knuth in The Art Of Computer Programming Volume 3,
|
||||
// // under the topic of sorting and search chapter 6.4.
|
||||
// // modified to fit with boolset-length
|
||||
// func (bl Bloom) DEKHash(b *[]byte) (l, h uint64) {
|
||||
// hash := uint64(len(*b))
|
||||
// for _, c := range *b {
|
||||
// hash = ((hash << 5) ^ (hash >> bl.shift)) ^ uint64(c)
|
||||
// }
|
||||
// h = hash >> bl.shift
|
||||
// l = hash << bl.sizeExp >> bl.sizeExp
|
||||
// return l, h
|
||||
// }
|
||||
225
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/sipHash.go
generated
vendored
Normal file
225
backend/services/mochi/vendor/github.com/AndreasBriese/bbloom/sipHash.go
generated
vendored
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
// Written in 2012 by Dmitry Chestnykh.
|
||||
//
|
||||
// To the extent possible under law, the author have dedicated all copyright
|
||||
// and related and neighboring rights to this software to the public domain
|
||||
// worldwide. This software is distributed without any warranty.
|
||||
// http://creativecommons.org/publicdomain/zero/1.0/
|
||||
//
|
||||
// Package siphash implements SipHash-2-4, a fast short-input PRF
|
||||
// created by Jean-Philippe Aumasson and Daniel J. Bernstein.
|
||||
|
||||
package bbloom
|
||||
|
||||
// Hash returns the 64-bit SipHash-2-4 of the given byte slice with two 64-bit
|
||||
// parts of 128-bit key: k0 and k1.
|
||||
func (bl Bloom) sipHash(p []byte) (l, h uint64) {
|
||||
// Initialization.
|
||||
v0 := uint64(8317987320269560794) // k0 ^ 0x736f6d6570736575
|
||||
v1 := uint64(7237128889637516672) // k1 ^ 0x646f72616e646f6d
|
||||
v2 := uint64(7816392314733513934) // k0 ^ 0x6c7967656e657261
|
||||
v3 := uint64(8387220255325274014) // k1 ^ 0x7465646279746573
|
||||
t := uint64(len(p)) << 56
|
||||
|
||||
// Compression.
|
||||
for len(p) >= 8 {
|
||||
|
||||
m := uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 |
|
||||
uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
|
||||
|
||||
v3 ^= m
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
v0 ^= m
|
||||
p = p[8:]
|
||||
}
|
||||
|
||||
// Compress last block.
|
||||
switch len(p) {
|
||||
case 7:
|
||||
t |= uint64(p[6]) << 48
|
||||
fallthrough
|
||||
case 6:
|
||||
t |= uint64(p[5]) << 40
|
||||
fallthrough
|
||||
case 5:
|
||||
t |= uint64(p[4]) << 32
|
||||
fallthrough
|
||||
case 4:
|
||||
t |= uint64(p[3]) << 24
|
||||
fallthrough
|
||||
case 3:
|
||||
t |= uint64(p[2]) << 16
|
||||
fallthrough
|
||||
case 2:
|
||||
t |= uint64(p[1]) << 8
|
||||
fallthrough
|
||||
case 1:
|
||||
t |= uint64(p[0])
|
||||
}
|
||||
|
||||
v3 ^= t
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
v0 ^= t
|
||||
|
||||
// Finalization.
|
||||
v2 ^= 0xff
|
||||
|
||||
// Round 1.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 2.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 3.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// Round 4.
|
||||
v0 += v1
|
||||
v1 = v1<<13 | v1>>51
|
||||
v1 ^= v0
|
||||
v0 = v0<<32 | v0>>32
|
||||
|
||||
v2 += v3
|
||||
v3 = v3<<16 | v3>>48
|
||||
v3 ^= v2
|
||||
|
||||
v0 += v3
|
||||
v3 = v3<<21 | v3>>43
|
||||
v3 ^= v0
|
||||
|
||||
v2 += v1
|
||||
v1 = v1<<17 | v1>>47
|
||||
v1 ^= v2
|
||||
v2 = v2<<32 | v2>>32
|
||||
|
||||
// return v0 ^ v1 ^ v2 ^ v3
|
||||
|
||||
hash := v0 ^ v1 ^ v2 ^ v3
|
||||
h = hash >> bl.shift
|
||||
l = hash << bl.shift >> bl.shift
|
||||
return l, h
|
||||
|
||||
}
|
||||
24
backend/services/mochi/vendor/github.com/alicebob/gopher-json/LICENSE
generated
vendored
Normal file
24
backend/services/mochi/vendor/github.com/alicebob/gopher-json/LICENSE
generated
vendored
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
This is free and unencumbered software released into the public domain.
|
||||
|
||||
Anyone is free to copy, modify, publish, use, compile, sell, or
|
||||
distribute this software, either in source code form or as a compiled
|
||||
binary, for any purpose, commercial or non-commercial, and by any
|
||||
means.
|
||||
|
||||
In jurisdictions that recognize copyright laws, the author or authors
|
||||
of this software dedicate any and all copyright interest in the
|
||||
software to the public domain. We make this dedication for the benefit
|
||||
of the public at large and to the detriment of our heirs and
|
||||
successors. We intend this dedication to be an overt act of
|
||||
relinquishment in perpetuity of all present and future rights to this
|
||||
software under copyright law.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
For more information, please refer to <http://unlicense.org/>
|
||||
7
backend/services/mochi/vendor/github.com/alicebob/gopher-json/README.md
generated
vendored
Normal file
7
backend/services/mochi/vendor/github.com/alicebob/gopher-json/README.md
generated
vendored
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# gopher-json [](https://godoc.org/layeh.com/gopher-json)
|
||||
|
||||
Package json is a simple JSON encoder/decoder for [gopher-lua](https://github.com/yuin/gopher-lua).
|
||||
|
||||
## License
|
||||
|
||||
Public domain
|
||||
33
backend/services/mochi/vendor/github.com/alicebob/gopher-json/doc.go
generated
vendored
Normal file
33
backend/services/mochi/vendor/github.com/alicebob/gopher-json/doc.go
generated
vendored
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
// Package json is a simple JSON encoder/decoder for gopher-lua.
|
||||
//
|
||||
// Documentation
|
||||
//
|
||||
// The following functions are exposed by the library:
|
||||
// decode(string): Decodes a JSON string. Returns nil and an error string if
|
||||
// the string could not be decoded.
|
||||
// encode(value): Encodes a value into a JSON string. Returns nil and an error
|
||||
// string if the value could not be encoded.
|
||||
//
|
||||
// The following types are supported:
|
||||
//
|
||||
// Lua | JSON
|
||||
// ---------+-----
|
||||
// nil | null
|
||||
// number | number
|
||||
// string | string
|
||||
// table | object: when table is non-empty and has only string keys
|
||||
// | array: when table is empty, or has only sequential numeric keys
|
||||
// | starting from 1
|
||||
//
|
||||
// Attempting to encode any other Lua type will result in an error.
|
||||
//
|
||||
// Example
|
||||
//
|
||||
// Below is an example usage of the library:
|
||||
// import (
|
||||
// luajson "layeh.com/gopher-json"
|
||||
// )
|
||||
//
|
||||
// L := lua.NewState()
|
||||
// luajson.Preload(s)
|
||||
package json
|
||||
189
backend/services/mochi/vendor/github.com/alicebob/gopher-json/json.go
generated
vendored
Normal file
189
backend/services/mochi/vendor/github.com/alicebob/gopher-json/json.go
generated
vendored
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// Preload adds json to the given Lua state's package.preload table. After it
|
||||
// has been preloaded, it can be loaded using require:
|
||||
//
|
||||
// local json = require("json")
|
||||
func Preload(L *lua.LState) {
|
||||
L.PreloadModule("json", Loader)
|
||||
}
|
||||
|
||||
// Loader is the module loader function.
|
||||
func Loader(L *lua.LState) int {
|
||||
t := L.NewTable()
|
||||
L.SetFuncs(t, api)
|
||||
L.Push(t)
|
||||
return 1
|
||||
}
|
||||
|
||||
var api = map[string]lua.LGFunction{
|
||||
"decode": apiDecode,
|
||||
"encode": apiEncode,
|
||||
}
|
||||
|
||||
func apiDecode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to decode"), 1)
|
||||
return 0
|
||||
}
|
||||
str := L.CheckString(1)
|
||||
|
||||
value, err := Decode(L, []byte(str))
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(value)
|
||||
return 1
|
||||
}
|
||||
|
||||
func apiEncode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to encode"), 1)
|
||||
return 0
|
||||
}
|
||||
value := L.CheckAny(1)
|
||||
|
||||
data, err := Encode(value)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(lua.LString(string(data)))
|
||||
return 1
|
||||
}
|
||||
|
||||
var (
|
||||
errNested = errors.New("cannot encode recursively nested tables to JSON")
|
||||
errSparseArray = errors.New("cannot encode sparse array")
|
||||
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
|
||||
)
|
||||
|
||||
type invalidTypeError lua.LValueType
|
||||
|
||||
func (i invalidTypeError) Error() string {
|
||||
return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
|
||||
}
|
||||
|
||||
// Encode returns the JSON encoding of value.
|
||||
func Encode(value lua.LValue) ([]byte, error) {
|
||||
return json.Marshal(jsonValue{
|
||||
LValue: value,
|
||||
visited: make(map[*lua.LTable]bool),
|
||||
})
|
||||
}
|
||||
|
||||
type jsonValue struct {
|
||||
lua.LValue
|
||||
visited map[*lua.LTable]bool
|
||||
}
|
||||
|
||||
func (j jsonValue) MarshalJSON() (data []byte, err error) {
|
||||
switch converted := j.LValue.(type) {
|
||||
case lua.LBool:
|
||||
data, err = json.Marshal(bool(converted))
|
||||
case lua.LNumber:
|
||||
data, err = json.Marshal(float64(converted))
|
||||
case *lua.LNilType:
|
||||
data = []byte(`null`)
|
||||
case lua.LString:
|
||||
data, err = json.Marshal(string(converted))
|
||||
case *lua.LTable:
|
||||
if j.visited[converted] {
|
||||
return nil, errNested
|
||||
}
|
||||
j.visited[converted] = true
|
||||
|
||||
key, value := converted.Next(lua.LNil)
|
||||
|
||||
switch key.Type() {
|
||||
case lua.LTNil: // empty table
|
||||
data = []byte(`[]`)
|
||||
case lua.LTNumber:
|
||||
arr := make([]jsonValue, 0, converted.Len())
|
||||
expectedKey := lua.LNumber(1)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTNumber {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
if expectedKey != key {
|
||||
err = errSparseArray
|
||||
return
|
||||
}
|
||||
arr = append(arr, jsonValue{value, j.visited})
|
||||
expectedKey++
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(arr)
|
||||
case lua.LTString:
|
||||
obj := make(map[string]jsonValue)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTString {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
obj[key.String()] = jsonValue{value, j.visited}
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(obj)
|
||||
default:
|
||||
err = errInvalidKeys
|
||||
}
|
||||
default:
|
||||
err = invalidTypeError(j.LValue.Type())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Decode converts the JSON encoded data to Lua values.
|
||||
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
|
||||
var value interface{}
|
||||
err := json.Unmarshal(data, &value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecodeValue(L, value), nil
|
||||
}
|
||||
|
||||
// DecodeValue converts the value to a Lua value.
|
||||
//
|
||||
// This function only converts values that the encoding/json package decodes to.
|
||||
// All other values will return lua.LNil.
|
||||
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
|
||||
switch converted := value.(type) {
|
||||
case bool:
|
||||
return lua.LBool(converted)
|
||||
case float64:
|
||||
return lua.LNumber(converted)
|
||||
case string:
|
||||
return lua.LString(converted)
|
||||
case json.Number:
|
||||
return lua.LString(converted)
|
||||
case []interface{}:
|
||||
arr := L.CreateTable(len(converted), 0)
|
||||
for _, item := range converted {
|
||||
arr.Append(DecodeValue(L, item))
|
||||
}
|
||||
return arr
|
||||
case map[string]interface{}:
|
||||
tbl := L.CreateTable(0, len(converted))
|
||||
for key, item := range converted {
|
||||
tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
|
||||
}
|
||||
return tbl
|
||||
case nil:
|
||||
return lua.LNil
|
||||
}
|
||||
|
||||
return lua.LNil
|
||||
}
|
||||
6
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/.gitignore
generated
vendored
Normal file
6
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/.gitignore
generated
vendored
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
/integration/redis_src/
|
||||
/integration/dump.rdb
|
||||
*.swp
|
||||
/integration/nodes.conf
|
||||
.idea/
|
||||
miniredis.iml
|
||||
225
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/CHANGELOG.md
generated
vendored
Normal file
225
backend/services/mochi/vendor/github.com/alicebob/miniredis/v2/CHANGELOG.md
generated
vendored
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
## Changelog
|
||||
|
||||
|
||||
### v2.23.0
|
||||
|
||||
- basic INFO support (thanks @kirill-a-belov)
|
||||
- support COUNT in SSCAN (thanks @Abdi-dd)
|
||||
- test and support Go 1.19
|
||||
- support LPOS (thanks @ianstarz)
|
||||
- support XPENDING, XGROUP {CREATECONSUMER,DESTROY,DELCONSUMER}, XINFO {CONSUMERS,GROUPS}, XCLAIM (thanks @sandyharvie)
|
||||
|
||||
|
||||
### v2.22.0
|
||||
|
||||
- set miniredis.DumpMaxLineLen to get more Dump() info (thanks @afjoseph)
|
||||
- fix invalid resposne of COMMAND (thanks @zsh1995)
|
||||
- fix possibility to generate duplicate IDs in XADD (thanks @readams)
|
||||
- adds support for XAUTOCLAIM min-idle parameter (thanks @readams)
|
||||
|
||||
|
||||
### v2.21.0
|
||||
|
||||
- support for GETEX (thanks @dntj)
|
||||
- support for GT and LT in ZADD (thanks @lsgndln)
|
||||
- support for XAUTOCLAIM (thanks @randall-fulton)
|
||||
|
||||
|
||||
### v2.20.0
|
||||
|
||||
- back to support Go >= 1.14 (thanks @ajatprabha and @marcind)
|
||||
|
||||
|
||||
### v2.19.0
|
||||
|
||||
- support for TYPE in SCAN (thanks @0xDiddi)
|
||||
- update BITPOS (thanks @dirkm)
|
||||
- fix a lua redis.call() return value (thanks @mpetronic)
|
||||
- update ZRANGE (thanks @valdemarpereira)
|
||||
|
||||
|
||||
### v2.18.0
|
||||
|
||||
- support for ZUNION (thanks @propan)
|
||||
- support for COPY (thanks @matiasinsaurralde and @rockitbaby)
|
||||
- support for LMOVE (thanks @btwear)
|
||||
|
||||
|
||||
### v2.17.0
|
||||
|
||||
- added miniredis.RunT(t)
|
||||
|
||||
|
||||
### v2.16.1
|
||||
|
||||
- fix ZINTERSTORE with wets (thanks @lingjl2010 and @okhowang)
|
||||
- fix exclusive ranges in XRANGE (thanks @joseotoro)
|
||||
|
||||
|
||||
### v2.16.0
|
||||
|
||||
- simplify some code (thanks @zonque)
|
||||
- support for EXAT/PXAT in SET
|
||||
- support for XTRIM (thanks @joseotoro)
|
||||
- support for ZRANDMEMBER
|
||||
- support for redis.log() in lua (thanks @dirkm)
|
||||
|
||||
|
||||
### v2.15.2
|
||||
|
||||
- Fix race condition in blocking code (thanks @zonque and @robx)
|
||||
- XREAD accepts '$' as ID (thanks @bradengroom)
|
||||
|
||||
|
||||
### v2.15.1
|
||||
|
||||
- EVAL should cache the script (thanks @guoshimin)
|
||||
|
||||
|
||||
### v2.15.0
|
||||
|
||||
- target redis 6.2 and added new args to various commands
|
||||
- support for all hyperlog commands (thanks @ilbaktin)
|
||||
- support for GETDEL (thanks @wszaranski)
|
||||
|
||||
|
||||
### v2.14.5
|
||||
|
||||
- added XPENDING
|
||||
- support for BLOCK option in XREAD and XREADGROUP
|
||||
|
||||
|
||||
### v2.14.4
|
||||
|
||||
- fix BITPOS error (thanks @xiaoyuzdy)
|
||||
- small fixes for XREAD, XACK, and XDEL. Mostly error cases.
|
||||
- fix empty EXEC return type (thanks @ashanbrown)
|
||||
- fix XDEL (thanks @svakili and @yvesf)
|
||||
- fix FLUSHALL for streams (thanks @svakili)
|
||||
|
||||
|
||||
### v2.14.3
|
||||
|
||||
- fix problem where Lua code didn't set the selected DB
|
||||
- update to redis 6.0.10 (thanks @lazappa)
|
||||
|
||||
|
||||
### v2.14.2
|
||||
|
||||
- update LUA dependency
|
||||
- deal with (p)unsubscribe when there are no channels
|
||||
|
||||
|
||||
### v2.14.1
|
||||
|
||||
- mod tidy
|
||||
|
||||
|
||||
### v2.14.0
|
||||
|
||||
- support for HELLO and the RESP3 protocol
|
||||
- KEEPTTL in SET (thanks @johnpena)
|
||||
|
||||
|
||||
### v2.13.3
|
||||
|
||||
- support Go 1.14 and 1.15
|
||||
- update the `Check...()` methods
|
||||
- support for XREAD (thanks @pieterlexis)
|
||||
|
||||
|
||||
### v2.13.2
|
||||
|
||||
- Use SAN instead of CN in self signed cert for testing (thanks @johejo)
|
||||
- Travis CI now tests against the most recent two versions of Go (thanks @johejo)
|
||||
- changed unit and integration tests to compare raw payloads, not parsed payloads
|
||||
- remove "redigo" dependency
|
||||
|
||||
|
||||
### v2.13.1
|
||||
|
||||
- added HSTRLEN
|
||||
- minimal support for ACL users in AUTH
|
||||
|
||||
|
||||
### v2.13.0
|
||||
|
||||
- added RunTLS(...)
|
||||
- added SetError(...)
|
||||
|
||||
|
||||
### v2.12.0
|
||||
|
||||
- redis 6
|
||||
- Lua json update (thanks @gsmith85)
|
||||
- CLUSTER commands (thanks @kratisto)
|
||||
- fix TOUCH
|
||||
- fix a shutdown race condition
|
||||
|
||||
|
||||
### v2.11.4
|
||||
|
||||
- ZUNIONSTORE now supports standard set types (thanks @wshirey)
|
||||
|
||||
|
||||
### v2.11.3
|
||||
|
||||
- support for TOUCH (thanks @cleroux)
|
||||
- support for cluster and stream commands (thanks @kak-tus)
|
||||
|
||||
|
||||
### v2.11.2
|
||||
|
||||
- make sure Lua code is executed concurrently
|
||||
- add command GEORADIUSBYMEMBER (thanks @kyeett)
|
||||
|
||||
|
||||
### v2.11.1
|
||||
|
||||
- globals protection for Lua code (thanks @vk-outreach)
|
||||
- HSET update (thanks @carlgreen)
|
||||
- fix BLPOP block on shutdown (thanks @Asalle)
|
||||
|
||||
|
||||
### v2.11.0
|
||||
|
||||
- added XRANGE/XREVRANGE, XADD, and XLEN (thanks @skateinmars)
|
||||
- added GEODIST
|
||||
- improved precision for geohashes, closer to what real redis does
|
||||
- use 128bit floats internally for INCRBYFLOAT and related (thanks @timnd)
|
||||
|
||||
|
||||
### v2.10.1
|
||||
|
||||
- added m.Server()
|
||||
|
||||
|
||||
### v2.10.0
|
||||
|
||||
- added UNLINK
|
||||
- fix DEL zero-argument case
|
||||
- cleanup some direct access commands
|
||||
- added GEOADD, GEOPOS, GEORADIUS, and GEORADIUS_RO
|
||||
|
||||
|
||||
### v2.9.1
|
||||
|
||||
- fix issue with ZRANGEBYLEX
|
||||
- fix issue with BRPOPLPUSH and direct access
|
||||
|
||||
|
||||
### v2.9.0
|
||||
|
||||
- proper versioned import of github.com/gomodule/redigo (thanks @yfei1)
|
||||
- fix messages generated by PSUBSCRIBE
|
||||
- optional internal seed (thanks @zikaeroh)
|
||||
|
||||
|
||||
### v2.8.0
|
||||
|
||||
Proper `v2` in go.mod.
|
||||
|
||||
|
||||
### older
|
||||
|
||||
See https://github.com/alicebob/miniredis/releases for the full changelog
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user