package serverevents import ( "fmt" di "git.apihub24.de/admin/generic-di" "github.com/gorilla/websocket" "net/http" "sync" "time" ) func init() { di.Injectable(newServerEventsMiddleware) } type MiddlewareOptions struct { Path string ContextLifetime time.Duration } type IMiddleware interface { Use(options MiddlewareOptions, mux *http.ServeMux) } type serverEventsMiddleware struct { options MiddlewareOptions upgrader websocket.Upgrader streamSubscribers map[string]bool emitter IEventEmitter parser IMessageParser registration IEventHandlerRegistration mutex sync.Mutex } func newServerEventsMiddleware() IMiddleware { return &serverEventsMiddleware{ upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, }, streamSubscribers: make(map[string]bool), emitter: di.Inject[IEventEmitter](), parser: di.Inject[IMessageParser](), registration: di.Inject[IEventHandlerRegistration](), mutex: sync.Mutex{}, } } func (middleware *serverEventsMiddleware) Use(options MiddlewareOptions, muxer *http.ServeMux) { middleware.options = options middleware.selectMethod(options.Path, func(w http.ResponseWriter, r *http.Request) { middleware.handleEventStream(w, r) }, muxer) } func (middleware *serverEventsMiddleware) handleEventStream(w http.ResponseWriter, r *http.Request) { conn, err := middleware.upgrader.Upgrade(w, r, nil) if err != nil { fmt.Println("Error upgrading:", err) return } defer func() { // do not handle Error _ = conn.Close() }() id := r.URL.Query().Get("id") context := di.Inject[IContext](id) context.SetId(id) // Locks the mutex before accessing 'middleware.streamSubscribers', // to prevent race conditions when reading/writing to the map. // This ensures that only one goroutine checks/registers at a time. // defer can't be used here as the stream listens for messages in an infinite loop middleware.mutex.Lock() if _, ok := middleware.streamSubscribers[id]; !ok { subscription := middleware.emitter.OnAll(func(ev Event) { // trigger the Backend Handlers handler, getHandlerErr := middleware.registration.GetHandler(ev.Type) if getHandlerErr != nil { println(fmt.Sprintf("no Handler found for Event %s", ev.Type)) } else { if handler.CanExecute(context) { handler.Handle(context) } } if ev.IsBackendOnly || ev.Filter == nil || !ev.Filter(context) { // the Event is Backend only or there is no socket (context) Filter // to send to all Sockets deliver context Filter with returns bool return } jsonData, jsonErr := middleware.parser.ToString(ev) if jsonErr != nil { println(fmt.Sprintf("Error parse event %s %s", jsonErr.Error(), ev.Type)) return } _ = conn.WriteMessage(websocket.TextMessage, []byte(jsonData)) }) defer func() { // Blocks the mutex again, as this block also accesses 'middleware.streamSubscribers'. // This is crucial to avoid race conditions when removing entries, // while new connections may be established. // defer cannot be used here either middleware.mutex.Lock() subscription.Unsubscribe() delete(middleware.streamSubscribers, id) // starts the Cleanup Process context.CleanupIn(middleware.options.ContextLifetime) middleware.mutex.Unlock() }() middleware.streamSubscribers[id] = true } middleware.mutex.Unlock() ctx := r.Context() for { messageType, data, err := conn.ReadMessage() if err != nil { return } if middleware.HandleMessage(conn, messageType, data) { return } select { case <-ctx.Done(): return } } } func (middleware *serverEventsMiddleware) selectMethod(path string, todo http.HandlerFunc, muxer *http.ServeMux) { if muxer == nil { http.HandleFunc(path, todo) } else { muxer.HandleFunc(path, todo) } } func (middleware *serverEventsMiddleware) HandleMessage(conn *websocket.Conn, messageType int, data []byte) bool { switch messageType { case websocket.PingMessage: pongErr := conn.WriteMessage(websocket.PongMessage, []byte{}) if pongErr != nil { println(fmt.Sprintf("error on send PongMessage: %s", pongErr.Error())) } break case websocket.CloseMessage: // return true to close the Websocket return true case websocket.TextMessage: ev, parseErr := middleware.parser.FromString(string(data)) if parseErr != nil { println(fmt.Sprintf("error on parse Event: %s data: %s", parseErr.Error(), string(data))) return false } // Event was dispatched over the Websocket so not send it back to Client! ev.IsBackendOnly = true middleware.emitter.Emit(ev) break case websocket.BinaryMessage: println(fmt.Sprintf("BinaryMessages are not supported")) break } return false }