server_events/middleware.go
2025-06-23 21:03:52 +02:00

125 lines
2.8 KiB
Go

package serverevents
import (
	"encoding/json"
	"fmt"
	"net/http"
	"sync"

	"github.com/gorilla/websocket"
)
var (
	// streamSubscribers records the "id" query values of /event/stream
	// connections that currently hold an emitter subscription.
	streamSubscribers = map[string]bool{}

	// eventHandlerStore maps an event name to the handler registered for it
	// by RegisterEvents.
	eventHandlerStore = map[string]EventHandler{}

	// effectStore maps an event name to the effect registered for it by
	// RegisterEvents; triggerEffect looks effects up here.
	effectStore = map[string]Effect{}

	// upgrader turns /event/stream requests into WebSocket connections.
	// NOTE(review): CheckOrigin accepts every Origin header, so any web page
	// can open an event stream from a visitor's browser; confirm this is
	// intended before exposing the endpoint publicly.
	upgrader = websocket.Upgrader{
		CheckOrigin: func(*http.Request) bool { return true },
	}
)
// RegisterEvents stores each handler under the name returned by its
// GetEventName and each effect under the name returned by its OnEvent (a later
// entry with the same name replaces an earlier one). It then mounts the
// /event/{event} and /event/stream routes on muxer, or on
// http.DefaultServeMux when muxer is nil.
func RegisterEvents(events []EventHandler, effects []Effect, muxer *http.ServeMux) {
	for i := range events {
		handler := events[i]
		eventHandlerStore[handler.GetEventName()] = handler
	}
	for i := range effects {
		eff := effects[i]
		effectStore[eff.OnEvent()] = eff
	}

	registerRoute(muxer)
	registerSender(muxer)
}
// registerSender mounts the WebSocket event stream at /event/stream. The
// emitter is fetched once here and shared by every stream connection.
func registerSender(muxer *http.ServeMux) {
	emitter := GetEventEmitter()
	streamHandler := func(w http.ResponseWriter, r *http.Request) {
		handleEventStream(w, r, emitter)
	}
	selectMethod("/event/stream", streamHandler, muxer)
}
func handleEventStream(w http.ResponseWriter, r *http.Request, emitter *EventEmitter) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println("Error upgrading:", err)
return
}
defer conn.Close()
id := r.URL.Query().Get("id")
if _, ok := streamSubscribers[id]; !ok {
subscription := emitter.OnAll(func(event Event) {
triggerEffect(event.Type, NewServerEventContext(w, r, &event))
if !event.IsBackendOnly {
jsonData, jsonErr := json.Marshal(event)
if jsonErr != nil {
fmt.Println("Error make json string", jsonErr)
return
}
_ = conn.WriteMessage(websocket.TextMessage, jsonData)
}
})
defer func() {
subscription.Unsubscribe()
delete(streamSubscribers, id)
}()
streamSubscribers[id] = true
}
ctx := r.Context()
for {
_, _, err := conn.ReadMessage()
if err != nil {
return
}
select {
case <-ctx.Done():
return
}
}
}
// registerRoute mounts handleEvent on the /event/{event} pattern, using muxer
// or, when muxer is nil, http.DefaultServeMux.
func registerRoute(muxer *http.ServeMux) {
	selectMethod("/event/{event}", handleEvent, muxer)
}
// handleEvent serves requests routed to /event/{event}. It looks up the
// handler registered under the context's event name (presumably the {event}
// path segment; see GetEventName). Requests are rejected as follows:
//   - no handler is registered: 404;
//   - the method is not POST: 405, via validateRequest;
//   - CanExecute refuses: the handler's own not-authorized response.
//
// Otherwise the handler runs, followed by any effect registered for the same
// event name.
func handleEvent(w http.ResponseWriter, r *http.Request) {
	evtCtx := NewServerEventContext(w, r, nil)
	handler := eventHandlerStore[evtCtx.GetEventName()]
	if handler == nil {
		// Returning without writing would make net/http send an empty
		// 200 OK, hiding typos in event names from the caller.
		evtCtx.Send(http.StatusNotFound, "text/plain", []byte("unknown event"))
		return
	}
	if !validateRequest(evtCtx) {
		return
	}
	if !handler.CanExecute(evtCtx) {
		handler.SendNotAuthorizedResponse(evtCtx)
		return
	}
	handler.Handle(evtCtx)
	triggerEffect(evtCtx.GetEventName(), evtCtx)
}
// validateRequest reports whether the request method is POST. For any other
// method it sends a 405 text/plain response through context and returns false.
func validateRequest(context *Context) bool {
	if context.GetMethod() == http.MethodPost {
		return true
	}
	context.Send(http.StatusMethodNotAllowed, "text/plain", []byte("only POST Method allowed!"))
	return false
}
// selectMethod registers todo for path on muxer, falling back to
// http.DefaultServeMux (via http.HandleFunc) when muxer is nil.
func selectMethod(path string, todo http.HandlerFunc, muxer *http.ServeMux) {
	register := http.HandleFunc
	if muxer != nil {
		register = muxer.HandleFunc
	}
	register(path, todo)
}
// triggerEffect runs the effect registered for eventName, if any, passing it
// context. Unknown event names are ignored.
func triggerEffect(eventName string, context *Context) {
	if effect := effectStore[eventName]; effect != nil {
		effect.Execute(context)
	}
}