Browse Source

[db] Enbancement add traffic fully transactional

Co-Authored-By: Alireza Ahmadi <[email protected]>
MHSanaei 1 year ago
parent
commit
1277285d08
5 changed files with 64 additions and 116 deletions
  1. 0 37
      web/job/check_inbound_job.go
  2. 3 5
      web/job/xray_traffic_job.go
  3. 60 70
      web/service/inbound.go
  4. 1 1
      web/service/xray.go
  5. 0 3
      web/web.go

+ 0 - 37
web/job/check_inbound_job.go

@@ -1,37 +0,0 @@
-package job
-
-import (
-	"x-ui/logger"
-	"x-ui/web/service"
-)
-
-type CheckInboundJob struct {
-	xrayService    service.XrayService
-	inboundService service.InboundService
-}
-
-func NewCheckInboundJob() *CheckInboundJob {
-	return new(CheckInboundJob)
-}
-
-func (j *CheckInboundJob) Run() {
-	needRestart, count, err := j.inboundService.DisableInvalidClients()
-	if err != nil {
-		logger.Warning("Error in disabling invalid clients:", err)
-	} else if count > 0 {
-		logger.Debugf("%v clients disabled", count)
-		if needRestart {
-			j.xrayService.SetToNeedRestart()
-		}
-	}
-
-	needRestart, count, err = j.inboundService.DisableInvalidInbounds()
-	if err != nil {
-		logger.Warning("Error in disabling invalid inbounds:", err)
-	} else if count > 0 {
-		logger.Debugf("%v inbounds disabled", count)
-		if needRestart {
-			j.xrayService.SetToNeedRestart()
-		}
-	}
-}

+ 3 - 5
web/job/xray_traffic_job.go

@@ -24,14 +24,12 @@ func (j *XrayTrafficJob) Run() {
 		logger.Warning("get xray traffic failed:", err)
 		return
 	}
-	err = j.inboundService.AddTraffic(traffics)
+	err, needRestart := j.inboundService.AddTraffic(traffics, clientTraffics)
 	if err != nil {
 		logger.Warning("add traffic failed:", err)
 	}
-
-	err = j.inboundService.AddClientTraffic(clientTraffics)
-	if err != nil {
-		logger.Warning("add client traffic failed:", err)
+	if needRestart {
+		j.xrayService.SetToNeedRestart()
 	}
 
 }

+ 60 - 70
web/service/inbound.go

@@ -194,38 +194,6 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
 	return inbound, needRestart, err
 }
 
-func (s *InboundService) AddInbounds(inbounds []*model.Inbound) error {
-	for _, inbound := range inbounds {
-		exist, err := s.checkPortExist(inbound.Port, 0)
-		if err != nil {
-			return err
-		}
-		if exist {
-			return common.NewError("Port already exists:", inbound.Port)
-		}
-	}
-
-	db := database.GetDB()
-	tx := db.Begin()
-	var err error
-	defer func() {
-		if err == nil {
-			tx.Commit()
-		} else {
-			tx.Rollback()
-		}
-	}()
-
-	for _, inbound := range inbounds {
-		err = tx.Save(inbound).Error
-		if err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
 func (s *InboundService) DelInbound(id int) (bool, error) {
 	db := database.GetDB()
 
@@ -687,35 +655,8 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin
 	return needRestart, tx.Save(oldInbound).Error
 }
 
-func (s *InboundService) AddTraffic(traffics []*xray.Traffic) error {
-	if len(traffics) == 0 {
-		return nil
-	}
-	// Update traffics in a single transaction
-	err := database.GetDB().Transaction(func(tx *gorm.DB) error {
-		for _, traffic := range traffics {
-			if traffic.IsInbound {
-				update := tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag).
-					Updates(map[string]interface{}{
-						"up":   gorm.Expr("up + ?", traffic.Up),
-						"down": gorm.Expr("down + ?", traffic.Down),
-					})
-				if update.Error != nil {
-					return update.Error
-				}
-			}
-		}
-		return nil
-	})
-
-	return err
-}
-
-func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err error) {
-	if len(traffics) == 0 {
-		return nil
-	}
-
+func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) {
+	var err error
 	db := database.GetDB()
 	tx := db.Begin()
 
@@ -726,13 +667,64 @@ func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err e
 			tx.Commit()
 		}
 	}()
+	err = s.addInboundTraffic(tx, inboundTraffics)
+	if err != nil {
+		return err, false
+	}
+	err = s.addClientTraffic(tx, clientTraffics)
+	if err != nil {
+		return err, false
+	}
+
+	needRestart1, count, err := s.disableInvalidClients(tx)
+	if err != nil {
+		logger.Warning("Error in disabling invalid clients:", err)
+	} else if count > 0 {
+		logger.Debugf("%v clients disabled", count)
+	}
+
+	needRestart2, count, err := s.disableInvalidInbounds(tx)
+	if err != nil {
+		logger.Warning("Error in disabling invalid inbounds:", err)
+	} else if count > 0 {
+		logger.Debugf("%v inbounds disabled", count)
+	}
+	return nil, (needRestart1 || needRestart2)
+}
+
+func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic) error {
+	if len(traffics) == 0 {
+		return nil
+	}
+
+	var err error
+
+	for _, traffic := range traffics {
+		if traffic.IsInbound {
+			err = tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag).
+				Updates(map[string]interface{}{
+					"up":   gorm.Expr("up + ?", traffic.Up),
+					"down": gorm.Expr("down + ?", traffic.Down),
+				}).Error
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) (err error) {
+	if len(traffics) == 0 {
+		return nil
+	}
 
 	emails := make([]string, 0, len(traffics))
 	for _, traffic := range traffics {
 		emails = append(emails, traffic.Email)
 	}
 	dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics))
-	err = db.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error
+	err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error
 	if err != nil {
 		return err
 	}
@@ -817,14 +809,13 @@ func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.Cl
 	return dbClientTraffics, nil
 }
 
-func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) {
-	db := database.GetDB()
+func (s *InboundService) disableInvalidInbounds(tx *gorm.DB) (bool, int64, error) {
 	now := time.Now().Unix() * 1000
 	needRestart := false
 
 	if p != nil {
 		var tags []string
-		err := db.Table("inbounds").
+		err := tx.Table("inbounds").
 			Select("inbounds.tag").
 			Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
 			Scan(&tags).Error
@@ -844,7 +835,7 @@ func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) {
 		s.xrayApi.Close()
 	}
 
-	result := db.Model(model.Inbound{}).
+	result := tx.Model(model.Inbound{}).
 		Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
 		Update("enable", false)
 	err := result.Error
@@ -852,8 +843,7 @@ func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) {
 	return needRestart, count, err
 }
 
-func (s *InboundService) DisableInvalidClients() (bool, int64, error) {
-	db := database.GetDB()
+func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) {
 	now := time.Now().Unix() * 1000
 	needRestart := false
 
@@ -863,7 +853,7 @@ func (s *InboundService) DisableInvalidClients() (bool, int64, error) {
 			Email string
 		}
 
-		err := db.Table("inbounds").
+		err := tx.Table("inbounds").
 			Select("inbounds.tag, client_traffics.email").
 			Joins("JOIN client_traffics ON inbounds.id = client_traffics.inbound_id").
 			Where("((client_traffics.total > 0 AND client_traffics.up + client_traffics.down >= client_traffics.total) OR (client_traffics.expiry_time > 0 AND client_traffics.expiry_time <= ?)) AND client_traffics.enable = ?", now, true).
@@ -883,7 +873,7 @@ func (s *InboundService) DisableInvalidClients() (bool, int64, error) {
 		}
 		s.xrayApi.Close()
 	}
-	result := db.Model(xray.ClientTraffic{}).
+	result := tx.Model(xray.ClientTraffic{}).
 		Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
 		Update("enable", false)
 	err := result.Error

+ 1 - 1
web/service/xray.go

@@ -69,7 +69,7 @@ func (s *XrayService) GetXrayConfig() (*xray.Config, error) {
 		return nil, err
 	}
 
-	s.inboundService.DisableInvalidClients()
+	s.inboundService.AddTraffic(nil, nil)
 
 	inbounds, err := s.inboundService.GetAllInbounds()
 	if err != nil {

+ 0 - 3
web/web.go

@@ -247,9 +247,6 @@ func (s *Server) startTask() {
 		s.cron.AddJob("@every 10s", job.NewXrayTrafficJob())
 	}()
 
-	// Check the inbound traffic every 30 seconds that the traffic exceeds and expires
-	s.cron.AddJob("@every 30s", job.NewCheckInboundJob())
-
 	// check client ips from log file every 10 sec
 	s.cron.AddJob("@every 10s", job.NewCheckClientIpJob())