Browse Source

Optimize database

Co-Authored-By: Alireza Ahmadi <[email protected]>
MHSanaei 1 year ago
parent
commit
2b460bac1a
1 changed files with 82 additions and 95 deletions
  1. 82 95
      web/service/inbound.go

+ 82 - 95
web/service/inbound.go

@@ -418,45 +418,35 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, index int) err
 	return db.Save(oldInbound).Error
 }
 
-func (s *InboundService) AddTraffic(traffics []*xray.Traffic) (err error) {
+func (s *InboundService) AddTraffic(traffics []*xray.Traffic) error {
 	if len(traffics) == 0 {
 		return nil
 	}
-	db := database.GetDB()
-	db = db.Model(model.Inbound{})
-	tx := db.Begin()
-	defer func() {
-		if err != nil {
-			tx.Rollback()
-		} else {
-			tx.Commit()
-		}
-	}()
-	for _, traffic := range traffics {
-		if traffic.IsInbound {
-			err = tx.Where("tag = ?", traffic.Tag).
-				UpdateColumns(map[string]interface{}{
-					"up":   gorm.Expr("up + ?", traffic.Up),
-					"down": gorm.Expr("down + ?", traffic.Down)}).Error
-			if err != nil {
-				return
+	// Update traffics in a single transaction
+	err := database.GetDB().Transaction(func(tx *gorm.DB) error {
+		for _, traffic := range traffics {
+			if traffic.IsInbound {
+				update := tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag).
+					Updates(map[string]interface{}{
+						"up":   gorm.Expr("up + ?", traffic.Up),
+						"down": gorm.Expr("down + ?", traffic.Down),
+					})
+				if update.Error != nil {
+					return update.Error
+				}
 			}
 		}
-	}
-	return
+		return nil
+	})
+
+	return err
 }
 func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err error) {
 	if len(traffics) == 0 {
 		return nil
 	}
 
-	traffics, err = s.adjustTraffics(traffics)
-	if err != nil {
-		return err
-	}
-
 	db := database.GetDB()
-	db = db.Model(xray.ClientTraffic{})
 	tx := db.Begin()
 
 	defer func() {
@@ -467,7 +457,32 @@ func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err e
 		}
 	}()
 
-	err = tx.Save(traffics).Error
+	emails := make([]string, 0, len(traffics))
+	for _, traffic := range traffics {
+		emails = append(emails, traffic.Email)
+	}
+	dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics))
+	err = db.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error
+	if err != nil {
+		return err
+	}
+
+	dbClientTraffics, err = s.adjustTraffics(tx, dbClientTraffics)
+	if err != nil {
+		return err
+	}
+
+	for dbTraffic_index := range dbClientTraffics {
+		for traffic_index := range traffics {
+			if dbClientTraffics[dbTraffic_index].Email == traffics[traffic_index].Email {
+				dbClientTraffics[dbTraffic_index].Up += traffics[traffic_index].Up
+				dbClientTraffics[dbTraffic_index].Down += traffics[traffic_index].Down
+				break
+			}
+		}
+	}
+
+	err = tx.Save(dbClientTraffics).Error
 	if err != nil {
 		logger.Warning("AddClientTraffic update data ", err)
 	}
@@ -475,81 +490,53 @@ func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err e
 	return nil
 }
 
-func (s *InboundService) adjustTraffics(traffics []*xray.ClientTraffic) (full_traffics []*xray.ClientTraffic, err error) {
-	db := database.GetDB()
-	dbInbound := db.Model(model.Inbound{})
-	txInbound := dbInbound.Begin()
-
-	defer func() {
-		if err != nil {
-			txInbound.Rollback()
-		} else {
-			txInbound.Commit()
-		}
-	}()
-
-	for _, traffic := range traffics {
-		inbound := &model.Inbound{}
-		client_traffic := &xray.ClientTraffic{}
-		err := db.Model(xray.ClientTraffic{}).Where("email = ?", traffic.Email).First(client_traffic).Error
-		if err != nil {
-			if err == gorm.ErrRecordNotFound {
-				logger.Warning(err, traffic.Email)
-			}
-			continue
+func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.ClientTraffic) ([]*xray.ClientTraffic, error) {
+	inboundIds := make([]int, 0, len(dbClientTraffics))
+	for _, dbClientTraffic := range dbClientTraffics {
+		if dbClientTraffic.ExpiryTime < 0 {
+			inboundIds = append(inboundIds, dbClientTraffic.InboundId)
 		}
-		client_traffic.Up += traffic.Up
-		client_traffic.Down += traffic.Down
+	}
 
-		err = txInbound.Where("id=?", client_traffic.InboundId).First(inbound).Error
+	if len(inboundIds) > 0 {
+		var inbounds []*model.Inbound
+		err := tx.Model(model.Inbound{}).Where("id IN (?)", inboundIds).Find(&inbounds).Error
 		if err != nil {
-			if err == gorm.ErrRecordNotFound {
-				logger.Warning(err, traffic.Email)
-			}
-			continue
+			return nil, err
 		}
-		// get clients
-		clients, err := s.getClients(inbound)
-		needUpdate := false
-		if err == nil {
-			for client_index, client := range clients {
-				if traffic.Email == client.Email {
-					if client.ExpiryTime < 0 {
-						clients[client_index].ExpiryTime = (time.Now().Unix() * 1000) - client.ExpiryTime
-						needUpdate = true
+		for inbound_index := range inbounds {
+			settings := map[string]interface{}{}
+			json.Unmarshal([]byte(inbounds[inbound_index].Settings), &settings)
+			clients, ok := settings["clients"].([]interface{})
+			if ok {
+				var newClients []interface{}
+				for client_index := range clients {
+					c := clients[client_index].(map[string]interface{})
+					for traffic_index := range dbClientTraffics {
+						if c["email"] == dbClientTraffics[traffic_index].Email {
+							oldExpiryTime := c["expiryTime"].(float64)
+							newExpiryTime := (time.Now().Unix() * 1000) - int64(oldExpiryTime)
+							c["expiryTime"] = newExpiryTime
+							dbClientTraffics[traffic_index].ExpiryTime = newExpiryTime
+							break
+						}
 					}
-					client_traffic.ExpiryTime = client.ExpiryTime
-					client_traffic.Total = client.TotalGB
-					break
+					newClients = append(newClients, interface{}(c))
 				}
-			}
-		}
-
-		if needUpdate {
-			settings := map[string]interface{}{}
-			json.Unmarshal([]byte(inbound.Settings), &settings)
-
-			// Convert clients to []interface to update clients in settings
-			var clientsInterface []interface{}
-			for _, c := range clients {
-				clientsInterface = append(clientsInterface, interface{}(c))
-			}
-
-			settings["clients"] = clientsInterface
-			modifiedSettings, err := json.MarshalIndent(settings, "", "  ")
-			if err != nil {
-				return nil, err
-			}
-
-			err = txInbound.Where("id=?", inbound.Id).Update("settings", string(modifiedSettings)).Error
-			if err != nil {
-				return nil, err
-			}
+				settings["clients"] = newClients
+				modifiedSettings, err := json.MarshalIndent(settings, "", "  ")
+				if err != nil {
+					return nil, err
+				}
+				inbounds[inbound_index].Settings = string(modifiedSettings)
 		}
-
-		full_traffics = append(full_traffics, client_traffic)
+		err = tx.Save(inbounds).Error
+		if err != nil {
+			logger.Warning("AddClientTraffic update inbounds ", err)
+			logger.Error(inbounds)
+		}	
 	}
-	return full_traffics, nil
+	return dbClientTraffics, nil
 }
 
 func (s *InboundService) DisableInvalidInbounds() (int64, error) {