hub.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. // Package websocket provides a WebSocket hub for real-time updates and notifications.
  2. package websocket
  3. import (
  4. "context"
  5. "encoding/json"
  6. "sync"
  7. "time"
  8. "github.com/mhsanaei/3x-ui/v3/logger"
  9. )
  10. // MessageType identifies the kind of WebSocket message.
  11. type MessageType string
  12. const (
  13. MessageTypeStatus MessageType = "status"
  14. MessageTypeTraffic MessageType = "traffic"
  15. MessageTypeInbounds MessageType = "inbounds"
  16. MessageTypeOutbounds MessageType = "outbounds"
  17. MessageTypeNodes MessageType = "nodes"
  18. MessageTypeNotification MessageType = "notification"
  19. MessageTypeXrayState MessageType = "xray_state"
  20. MessageTypeClientStats MessageType = "client_stats"
  21. MessageTypeClients MessageType = "clients"
  22. MessageTypeInvalidate MessageType = "invalidate"
  23. maxMessageSize = 10 * 1024 * 1024 // 10MB
  24. enqueueTimeout = 100 * time.Millisecond
  25. clientSendQueue = 512 // ~50s of buffering for a momentarily slow browser.
  26. hubBroadcastQueue = 2048 // Headroom for cron-storm + admin-mutation bursts.
  27. hubOpsQueue = 128 // Backlog for register+unregister bursts (page reloads, disconnect storms).
  28. minBroadcastInterval = 250 * time.Millisecond
  29. hubRestartAttempts = 3
  30. )
  31. type clientOpKind int
  32. const (
  33. opRegister clientOpKind = iota
  34. opUnregister
  35. )
  36. type clientOp struct {
  37. kind clientOpKind
  38. c *Client
  39. }
  40. // NewClient builds a Client ready for hub registration.
  41. func NewClient(id string) *Client {
  42. return &Client{
  43. ID: id,
  44. Send: make(chan []byte, clientSendQueue),
  45. }
  46. }
  47. // Message is the wire format sent to clients.
  48. type Message struct {
  49. Type MessageType `json:"type"`
  50. Payload any `json:"payload"`
  51. Time int64 `json:"time"`
  52. }
  53. // Client represents a single WebSocket connection.
  54. type Client struct {
  55. ID string
  56. Send chan []byte
  57. closeOnce sync.Once
  58. }
  59. // Hub fan-outs messages to all connected clients.
  60. type Hub struct {
  61. clients map[*Client]struct{}
  62. broadcast chan []byte
  63. ops chan clientOp
  64. mu sync.RWMutex
  65. ctx context.Context
  66. cancel context.CancelFunc
  67. throttleMu sync.Mutex
  68. lastBroadcast map[MessageType]time.Time
  69. }
  70. // NewHub creates a hub. Call Run in a goroutine to start its event loop.
  71. func NewHub() *Hub {
  72. ctx, cancel := context.WithCancel(context.Background())
  73. return &Hub{
  74. clients: make(map[*Client]struct{}),
  75. broadcast: make(chan []byte, hubBroadcastQueue),
  76. ops: make(chan clientOp, hubOpsQueue),
  77. ctx: ctx,
  78. cancel: cancel,
  79. lastBroadcast: make(map[MessageType]time.Time),
  80. }
  81. }
  82. var throttledMessageTypes = map[MessageType]struct{}{
  83. MessageTypeInbounds: {},
  84. MessageTypeOutbounds: {},
  85. MessageTypeTraffic: {},
  86. MessageTypeClientStats: {},
  87. }
  88. func (h *Hub) shouldThrottle(msgType MessageType) bool {
  89. if _, gated := throttledMessageTypes[msgType]; !gated {
  90. return false
  91. }
  92. h.throttleMu.Lock()
  93. defer h.throttleMu.Unlock()
  94. now := time.Now()
  95. if last, ok := h.lastBroadcast[msgType]; ok && now.Sub(last) < minBroadcastInterval {
  96. return true
  97. }
  98. h.lastBroadcast[msgType] = now
  99. return false
  100. }
  101. // Run drives the hub. The inner loop is wrapped in a panic-recovery harness
  102. // that retries up to hubRestartAttempts times with backoff so a transient
  103. // panic doesn't permanently kill real-time updates for commercial deployments.
  104. // After the cap, the hub stays down and the frontend falls back to REST polling.
  105. func (h *Hub) Run() {
  106. for attempt := range hubRestartAttempts {
  107. stopped := h.runOnce()
  108. if stopped {
  109. return
  110. }
  111. if attempt < hubRestartAttempts-1 {
  112. wait := time.Duration(1<<attempt) * time.Second // 1s, 2s, 4s
  113. logger.Errorf("WebSocket hub crashed, restarting in %s (%d/%d)", wait, attempt+1, hubRestartAttempts-1)
  114. select {
  115. case <-time.After(wait):
  116. case <-h.ctx.Done():
  117. return
  118. }
  119. }
  120. }
  121. logger.Error("WebSocket hub stopped after exhausting restart attempts")
  122. }
  123. // runOnce drives the event loop once and returns true if the hub stopped
  124. // cleanly (context cancelled). On panic, recover logs and returns false so
  125. // Run can decide whether to retry.
  126. func (h *Hub) runOnce() (stopped bool) {
  127. defer func() {
  128. if r := recover(); r != nil {
  129. logger.Errorf("WebSocket hub panic recovered: %v", r)
  130. stopped = false
  131. }
  132. }()
  133. for {
  134. select {
  135. case <-h.ctx.Done():
  136. h.shutdown()
  137. return true
  138. case op := <-h.ops:
  139. if op.c == nil {
  140. continue
  141. }
  142. switch op.kind {
  143. case opRegister:
  144. h.mu.Lock()
  145. h.clients[op.c] = struct{}{}
  146. n := len(h.clients)
  147. h.mu.Unlock()
  148. logger.Debugf("WebSocket client connected: %s (total: %d)", op.c.ID, n)
  149. case opUnregister:
  150. h.removeClient(op.c)
  151. }
  152. case msg := <-h.broadcast:
  153. h.fanout(msg)
  154. }
  155. }
  156. }
  157. // shutdown closes all client send channels and clears the registry.
  158. func (h *Hub) shutdown() {
  159. h.mu.Lock()
  160. for c := range h.clients {
  161. c.closeOnce.Do(func() { close(c.Send) })
  162. }
  163. h.clients = make(map[*Client]struct{})
  164. h.mu.Unlock()
  165. logger.Info("WebSocket hub stopped")
  166. }
  167. // removeClient deletes a client and closes its send channel exactly once.
  168. func (h *Hub) removeClient(c *Client) {
  169. h.mu.Lock()
  170. if _, ok := h.clients[c]; ok {
  171. delete(h.clients, c)
  172. c.closeOnce.Do(func() { close(c.Send) })
  173. }
  174. n := len(h.clients)
  175. h.mu.Unlock()
  176. logger.Debugf("WebSocket client disconnected: %s (total: %d)", c.ID, n)
  177. }
  178. // fanout delivers msg to every client. Each send is non-blocking — a client
  179. // whose buffer is full is collected for direct removal at the end. We do NOT
  180. // route slow-client unregistrations through the unregister channel: under
  181. // burst load (panel restart, network blip) that channel can fill up while the
  182. // hub itself is the consumer, causing a self-deadlock.
  183. func (h *Hub) fanout(msg []byte) {
  184. if msg == nil {
  185. return
  186. }
  187. h.mu.RLock()
  188. if len(h.clients) == 0 {
  189. h.mu.RUnlock()
  190. return
  191. }
  192. targets := make([]*Client, 0, len(h.clients))
  193. for c := range h.clients {
  194. targets = append(targets, c)
  195. }
  196. h.mu.RUnlock()
  197. var dead []*Client
  198. for _, c := range targets {
  199. if !trySend(c, msg) {
  200. dead = append(dead, c)
  201. }
  202. }
  203. if len(dead) == 0 {
  204. return
  205. }
  206. h.mu.Lock()
  207. for _, c := range dead {
  208. if _, ok := h.clients[c]; ok {
  209. delete(h.clients, c)
  210. c.closeOnce.Do(func() { close(c.Send) })
  211. logger.Debugf("WebSocket client %s send buffer full, disconnected", c.ID)
  212. }
  213. }
  214. h.mu.Unlock()
  215. }
  216. // trySend performs a non-blocking write to the client's Send channel.
  217. // Returns false if the client should be evicted (full buffer or closed channel).
  218. // A defer-recover guards against the rare race where the channel was closed
  219. // concurrently — sending on a closed channel always panics, even with select+default.
  220. func trySend(c *Client, msg []byte) (ok bool) {
  221. defer func() {
  222. if r := recover(); r != nil {
  223. ok = false
  224. }
  225. }()
  226. select {
  227. case c.Send <- msg:
  228. return true
  229. default:
  230. return false
  231. }
  232. }
  233. // Broadcast serializes payload and queues it for delivery to all clients.
  234. // If the serialized message exceeds maxMessageSize, an invalidate signal is
  235. // queued instead so the frontend re-fetches via REST. Broadcasts of throttled
  236. // message types (see throttledMessageTypes) within minBroadcastInterval of
  237. // the previous one are dropped — the next legitimate mutation will push the
  238. // fresh state.
  239. func (h *Hub) Broadcast(messageType MessageType, payload any) {
  240. if h == nil || payload == nil || h.GetClientCount() == 0 {
  241. return
  242. }
  243. if h.shouldThrottle(messageType) {
  244. return
  245. }
  246. data, err := json.Marshal(Message{
  247. Type: messageType,
  248. Payload: payload,
  249. Time: time.Now().UnixMilli(),
  250. })
  251. if err != nil {
  252. logger.Error("WebSocket marshal failed:", err)
  253. return
  254. }
  255. if len(data) > maxMessageSize {
  256. logger.Debugf("WebSocket payload %d bytes exceeds limit, sending invalidate for %s", len(data), messageType)
  257. h.broadcastInvalidate(messageType)
  258. return
  259. }
  260. h.enqueue(data)
  261. }
  262. // broadcastInvalidate queues a lightweight signal telling clients to re-fetch
  263. // the named data type via REST.
  264. func (h *Hub) broadcastInvalidate(originalType MessageType) {
  265. data, err := json.Marshal(Message{
  266. Type: MessageTypeInvalidate,
  267. Payload: map[string]string{"type": string(originalType)},
  268. Time: time.Now().UnixMilli(),
  269. })
  270. if err != nil {
  271. logger.Error("WebSocket invalidate marshal failed:", err)
  272. return
  273. }
  274. h.enqueue(data)
  275. }
  276. // enqueue submits raw bytes to the broadcast channel. Dropped on backpressure
  277. // (channel full for >100ms) or shutdown.
  278. func (h *Hub) enqueue(data []byte) {
  279. select {
  280. case h.broadcast <- data:
  281. case <-time.After(enqueueTimeout):
  282. logger.Warning("WebSocket broadcast channel full, dropping message")
  283. case <-h.ctx.Done():
  284. }
  285. }
  286. // GetClientCount returns the number of connected clients.
  287. func (h *Hub) GetClientCount() int {
  288. if h == nil {
  289. return 0
  290. }
  291. h.mu.RLock()
  292. defer h.mu.RUnlock()
  293. return len(h.clients)
  294. }
  295. // Register adds a client to the hub.
  296. func (h *Hub) Register(c *Client) {
  297. if h == nil || c == nil {
  298. return
  299. }
  300. select {
  301. case h.ops <- clientOp{kind: opRegister, c: c}:
  302. case <-h.ctx.Done():
  303. }
  304. }
  305. // Unregister removes a client from the hub. Sends through the same ordered
  306. // ops channel as Register so a register-then-unregister sequence from one
  307. // goroutine is processed in program order — otherwise an unregister could
  308. // land in the map before its register and silently no-op, leaking the entry.
  309. //
  310. // On a saturated ops channel (disconnect storm) we fall back to a bounded
  311. // timeout drop rather than direct removal: a direct delete on a not-yet-
  312. // registered client is precisely the ordering bug we fix here. Stragglers
  313. // get evicted by fanout when their Send buffer fills.
  314. func (h *Hub) Unregister(c *Client) {
  315. if h == nil || c == nil {
  316. return
  317. }
  318. select {
  319. case h.ops <- clientOp{kind: opUnregister, c: c}:
  320. case <-time.After(enqueueTimeout):
  321. logger.Warningf("WebSocket ops channel full, dropping unregister for %s", c.ID)
  322. case <-h.ctx.Done():
  323. }
  324. }
  325. // Stop signals the hub to shut down and close all client connections.
  326. func (h *Hub) Stop() {
  327. if h != nil && h.cancel != nil {
  328. h.cancel()
  329. }
  330. }