Переглянути джерело

fix: restart remote xray after disabling a client to kill active sessions (#4918)

* fix(node-traffic): restart remote xray after disabling clients to kill active sessions

When a client's traffic limit is reached on a remote node, the panel pushes
enable=false to that node via UpdateInbound. The node calls RemoveUser on its
local xray, which blocks new connections but leaves any already-established TCP
session alive. The user could continue browsing/downloading until they
disconnected voluntarily.

Fix: after successfully pushing a client disable to a remote node, call
RestartXray on that node. This mirrors what already happens for the local node
when the "Restart Xray on client disable" setting is enabled (default: on),
and ensures active sessions are terminated immediately on all nodes where the
client was disabled.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>

* refactor(node): restart remote xray after tx commit, not inside it

Move the remote RestartXray calls out of the addTraffic write
transaction. disableInvalidClients now returns the affected remote
node IDs instead of restarting their xray while the SQLite write lock
is held; AddTraffic performs the restart after the transaction commits
via restartRemoteNodesOnDisable. Avoids holding the serialized write
lock across slow per-node restart RPCs.

---------

Co-authored-by: Claude Sonnet 4.6 <[email protected]>
Co-authored-by: Sanaei <[email protected]>
Hamed 7 годин тому
батько
коміт
d6d2085d60
1 змінених файлів з 50 додано та 12 видалено
  1. 50 12
      web/service/inbound.go

+ 50 - 12
web/service/inbound.go

@@ -1789,15 +1789,19 @@ func (s *InboundService) setRemoteTrafficLocked(nodeID int, snap *runtime.Traffi
 }
 
 func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (needRestart bool, clientsDisabled bool, err error) {
+	var disabledNodeIDs []int
 	err = submitTrafficWrite(func() error {
 		var inner error
-		needRestart, clientsDisabled, inner = s.addTrafficLocked(inboundTraffics, clientTraffics)
+		needRestart, clientsDisabled, disabledNodeIDs, inner = s.addTrafficLocked(inboundTraffics, clientTraffics)
 		return inner
 	})
+	if err == nil && len(disabledNodeIDs) > 0 {
+		s.restartRemoteNodesOnDisable(disabledNodeIDs)
+	}
 	return
 }
 
-func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (bool, bool, error) {
+func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (bool, bool, []int, error) {
 	var err error
 	db := database.GetDB()
 	tx := db.Begin()
@@ -1811,11 +1815,11 @@ func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clien
 	}()
 	err = s.addInboundTraffic(tx, inboundTraffics)
 	if err != nil {
-		return false, false, err
+		return false, false, nil, err
 	}
 	err = s.addClientTraffic(tx, clientTraffics)
 	if err != nil {
-		return false, false, err
+		return false, false, nil, err
 	}
 
 	needRestart0, count, err := s.autoRenewClients(tx)
@@ -1826,7 +1830,7 @@ func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clien
 	}
 
 	disabledClientsCount := int64(0)
-	needRestart1, count, err := s.disableInvalidClients(tx)
+	needRestart1, count, disabledNodeIDs, err := s.disableInvalidClients(tx)
 	if err != nil {
 		logger.Warning("Error in disabling invalid clients:", err)
 	} else if count > 0 {
@@ -1840,7 +1844,7 @@ func (s *InboundService) addTrafficLocked(inboundTraffics []*xray.Traffic, clien
 	} else if count > 0 {
 		logger.Debugf("%v inbounds disabled", count)
 	}
-	return needRestart0 || needRestart1 || needRestart2, disabledClientsCount > 0, nil
+	return needRestart0 || needRestart1 || needRestart2, disabledClientsCount > 0, disabledNodeIDs, nil
 }
 
 func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic) error {
@@ -2196,7 +2200,7 @@ func (s *InboundService) disableInvalidInbounds(tx *gorm.DB) (bool, int64, error
 	return needRestart, count, err
 }
 
-func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) {
+func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, []int, error) {
 	now := time.Now().Unix() * 1000
 	needRestart := false
 
@@ -2205,10 +2209,10 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error)
 		Where("((total > 0 AND up + down >= total) OR (expiry_time > 0 AND expiry_time <= ?)) AND enable = ?", now, true).
 		Find(&depletedRows).Error
 	if err != nil {
-		return false, 0, err
+		return false, 0, nil, err
 	}
 	if len(depletedRows) == 0 {
-		return false, 0, nil
+		return false, 0, nil, nil
 	}
 
 	depletedEmails := make([]string, 0, len(depletedRows))
@@ -2236,7 +2240,7 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error)
 			WHERE clients.email IN ?
 		`, depletedEmails).Scan(&targets).Error
 		if err != nil {
-			return false, 0, err
+			return false, 0, nil, err
 		}
 	}
 
@@ -2283,7 +2287,7 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error)
 	err = result.Error
 	count := result.RowsAffected
 	if err != nil {
-		return needRestart, count, err
+		return needRestart, count, nil, err
 	}
 
 	if len(depletedEmails) > 0 {
@@ -2294,6 +2298,7 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error)
 		}
 	}
 
+	disabledNodeIDs := make(map[int]struct{})
 	for inboundID, group := range remoteByInbound {
 		emails := make(map[string]struct{}, len(group))
 		for _, t := range group {
@@ -2302,10 +2307,43 @@ func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error)
 		if pushErr := s.disableRemoteClients(tx, inboundID, emails); pushErr != nil {
 			logger.Warning("disableInvalidClients: push to remote failed for inbound", inboundID, ":", pushErr)
 			needRestart = true
+		} else {
+			for _, t := range group {
+				if t.NodeID != nil {
+					disabledNodeIDs[*t.NodeID] = struct{}{}
+				}
+			}
 		}
 	}
 
-	return needRestart, count, nil
+	nodeIDs := make([]int, 0, len(disabledNodeIDs))
+	for nodeID := range disabledNodeIDs {
+		nodeIDs = append(nodeIDs, nodeID)
+	}
+
+	return needRestart, count, nodeIDs, nil
+}
+
+func (s *InboundService) restartRemoteNodesOnDisable(nodeIDs []int) {
+	restartOnDisable, err := (&SettingService{}).GetRestartXrayOnClientDisable()
+	if err != nil {
+		logger.Warning("disableInvalidClients: get RestartXrayOnClientDisable failed:", err)
+		return
+	}
+	if !restartOnDisable {
+		return
+	}
+	for _, nodeID := range nodeIDs {
+		nodeIDCopy := nodeID
+		rt, rtErr := runtime.GetManager().RuntimeFor(&nodeIDCopy)
+		if rtErr != nil {
+			logger.Warning("disableInvalidClients: get runtime for node", nodeID, "failed:", rtErr)
+			continue
+		}
+		if rtErr = rt.RestartXray(context.Background()); rtErr != nil {
+			logger.Warning("disableInvalidClients: restart xray on node", nodeID, "failed:", rtErr)
+		}
+	}
 }
 
 // markClientsDisabledInSettings flips client.enable=false in the inbound's