diff --git a/backend/services/controller/.env b/backend/services/controller/.env index e23de3c..0cd3ef9 100644 --- a/backend/services/controller/.env +++ b/backend/services/controller/.env @@ -36,4 +36,5 @@ WS_TLS="" WS_AUTH="" WS_ROUTE="" WS_DISABLE="" +WS_SKIP_VERIFY="" # ---------------------------------------------------------------------------- # \ No newline at end of file diff --git a/backend/services/controller/cmd/oktopus/main.go b/backend/services/controller/cmd/oktopus/main.go index 7cb6ba9..37cd03c 100755 --- a/backend/services/controller/cmd/oktopus/main.go +++ b/backend/services/controller/cmd/oktopus/main.go @@ -77,6 +77,7 @@ func main() { flWsPort := flag.String("ws_port", lookupEnvOrString("WS_PORT", "8080"), "Websocket server port") flWsRoute := flag.String("ws_route", lookupEnvOrString("WS_ROUTE", "/ws/controller"), "Websocket server route") flWsTls := flag.Bool("ws_tls", lookupEnvOrBool("WS_TLS", false), "Websocket server tls") + flWsSkipVerify := flag.Bool("ws_skip_verify", lookupEnvOrBool("WS_SKIP_VERIFY", false), "Websocket skip tls certificate verify") flDisableWs := flag.Bool("ws_disable", lookupEnvOrBool("WS_DISABLE", false), "Disable WS MTP") flDisableStomp := flag.Bool("stomp_disable", lookupEnvOrBool("STOMP_DISABLE", false), "Disable STOMP MTP") flDisableMqtt := flag.Bool("mqtt_disable", lookupEnvOrBool("MQTT_DISABLE", false), "Disable MQTT MTP") @@ -169,14 +170,15 @@ func main() { go func() { wsClient = ws.Ws{ - Addr: *flWsAddr, - Port: *flWsPort, - Token: *flWsToken, - Route: *flWsRoute, - Auth: *flWsAuth, - TLS: *flWsTls, - DB: database, - Ctx: ctx, + Addr: *flWsAddr, + Port: *flWsPort, + Token: *flWsToken, + Route: *flWsRoute, + Auth: *flWsAuth, + TLS: *flWsTls, + InsecureSkipVerify: *flWsSkipVerify, + DB: database, + Ctx: ctx, } wsDone = make(chan os.Signal, 1) diff --git a/backend/services/controller/internal/ws/ws.go b/backend/services/controller/internal/ws/ws.go index b1e9fb2..e03c81f 100644 --- a/backend/services/controller/internal/ws/ws.go +++ b/backend/services/controller/internal/ws/ws.go @@ -2,6 +2,7 @@ package ws import ( "context" + "crypto/tls" "encoding/json" "log" "reflect" @@ -17,16 +18,17 @@ import ( ) type Ws struct { - Addr string - Port string - Token string - Route string - Auth bool - TLS bool - Ctx context.Context - NewDeviceQueue map[string]string - NewDevQMutex *sync.Mutex - DB db.Database + Addr string + Port string + Token string + Route string + Auth bool + TLS bool + InsecureSkipVerify bool + Ctx context.Context + NewDeviceQueue map[string]string + NewDevQMutex *sync.Mutex + DB db.Database } const ( @@ -47,9 +49,14 @@ type deviceStatus struct { var wsConn *websocket.Conn func (w *Ws) Connect() { + log.Println("Connecting to WS endpoint...") - // communication with devices - wsUrl := "ws://" + w.Addr + ":" + w.Port + w.Route + prefix := "ws://" + if w.TLS { + prefix = "wss://" + } + + wsUrl := prefix + w.Addr + ":" + w.Port + w.Route if w.Auth { log.Println("WS token:", w.Token) @@ -57,10 +64,16 @@ func (w *Ws) Connect() { wsUrl = wsUrl + "?token=" + w.Token } + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: w.InsecureSkipVerify, + }, + } + // Keeps trying to connect to the WS endpoint until it succeeds or receives a stop signal - go func() { + go func(dialer websocket.Dialer) { for { - c, _, err := websocket.DefaultDialer.Dial(wsUrl, nil) + c, _, err := dialer.Dial(wsUrl, nil) if err != nil { log.Printf("Error to connect to %s, err: %s", wsUrl, err) time.Sleep(WS_CONNECTION_RETRY) @@ -72,7 +85,7 @@ func (w *Ws) Connect() { go w.Subscribe() break } - }() + }(dialer) } func (w *Ws) Disconnect() { @@ -100,13 +113,18 @@ func (w *Ws) Subscribe() { w.NewDeviceQueue = make(map[string]string) for { - //TODO: deal with message in new go routine msgType, wsMsg, err := wsConn.ReadMessage() if err != nil { - log.Println("read:", err) + if websocket.IsCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("websocket error: %v", err) + w.Connect() + return + } + log.Println("websocket unexpected error:", err) return } + //TODO: deal with message in new go routine if msgType == websocket.TextMessage { var deviceStatus deviceStatus err = json.Unmarshal(wsMsg, &deviceStatus)