Browse Source

feat: Add MySQL database support (#3024)

* feat: Add MySQL database support

- Add MySQL database support with environment-based configuration
- Fix MySQL compatibility issue with 'key' column name
- Maintain SQLite as default database
- Add proper validation for MySQL configuration
- Test and verify compatibility with existing database
- Replaced raw SQL queries using JSON_EACH functions with standard GORM queries
- Modified functions to handle JSON parsing in Go code instead of database since JSON_EACH is not available on MySQL or MariaDB:
  - getAllEmails()
  - GetClientTrafficByID()
  - getFallbackMaster()
  - MigrationRemoveOrphanedTraffics()

The system now supports both MySQL and SQLite databases, with SQLite remaining as the default option. MySQL connection is only used when explicitly configured through environment variables.

* refactor: prefix env variables of database with XUI_ to support direct environment usage without .env file

All database configuration environment variables now start with the XUI_ prefix to avoid conflicts and allow configuration via system-level environment variables, not just the .env file.
Ali Golzar 10 tháng trước cách đây
mục cha
commit
3850e2f070
10 tập tin đã thay đổi với 239 bổ sung52 xóa
  1. 8 0
      .env.example
  2. 49 0
      config/config.go
  3. 29 9
      database/db.go
  4. 1 1
      database/model/model.go
  5. 2 0
      go.mod
  6. 5 0
      go.sum
  7. 1 0
      main.go
  8. 58 16
      sub/subService.go
  9. 84 24
      web/service/inbound.go
  10. 2 2
      web/service/setting.go

+ 8 - 0
.env.example

@@ -0,0 +1,8 @@
+XUI_DB_CONNECTION=sqlite
+
+# If DB connection is "mysql"
+# XUI_DB_HOST=127.0.0.1
+# XUI_DB_PORT=3306
+# XUI_DB_DATABASE=xui
+# XUI_DB_USERNAME=root
+# XUI_DB_PASSWORD=

+ 49 - 0
config/config.go

@@ -3,6 +3,7 @@ package config
 import (
 	_ "embed"
 	"fmt"
+	"log"
 	"os"
 	"strings"
 )
@@ -62,7 +63,55 @@ func GetDBFolderPath() string {
 	return dbFolderPath
 }
 
+// DatabaseConfig holds the database configuration
+type DatabaseConfig struct {
+	Connection string
+	Host       string
+	Port       string
+	Database   string
+	Username   string
+	Password   string
+}
+
+// GetDatabaseConfig returns the database configuration from environment variables
+func GetDatabaseConfig() (*DatabaseConfig, error) {
+	config := &DatabaseConfig{
+		Connection: strings.ToLower(os.Getenv("XUI_DB_CONNECTION")),
+		Host:       os.Getenv("XUI_DB_HOST"),
+		Port:       os.Getenv("XUI_DB_PORT"),
+		Database:   os.Getenv("XUI_DB_DATABASE"),
+		Username:   os.Getenv("XUI_DB_USERNAME"),
+		Password:   os.Getenv("XUI_DB_PASSWORD"),
+	}
+
+	if config.Connection == "mysql" {
+		if config.Host == "" || config.Database == "" || config.Username == "" {
+			return nil, fmt.Errorf("missing required MySQL configuration: host, database, and username are required")
+		}
+		if config.Port == "" {
+			config.Port = "3306"
+		}
+	}
+
+	return config, nil
+}
+
 func GetDBPath() string {
+	config, err := GetDatabaseConfig()
+	if err != nil {
+		log.Fatalf("Error getting database config: %v", err)
+	}
+
+	if config.Connection == "mysql" {
+		return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local",
+			config.Username,
+			config.Password,
+			config.Host,
+			config.Port,
+			config.Database)
+	}
+
+	// Connection is sqlite
 	return fmt.Sprintf("%s/%s.db", GetDBFolderPath(), GetName())
 }
 

+ 29 - 9
database/db.go

@@ -9,14 +9,15 @@ import (
 	"path"
 	"slices"
 
+	"gorm.io/driver/mysql"
+	"gorm.io/driver/sqlite"
+	"gorm.io/gorm"
+	"gorm.io/gorm/logger"
+
 	"x-ui/config"
 	"x-ui/database/model"
 	"x-ui/util/crypto"
 	"x-ui/xray"
-
-	"gorm.io/driver/sqlite"
-	"gorm.io/gorm"
-	"gorm.io/gorm/logger"
 )
 
 var db *gorm.DB
@@ -114,12 +115,22 @@ func isTableEmpty(tableName string) (bool, error) {
 }
 
 func InitDB(dbPath string) error {
-	dir := path.Dir(dbPath)
-	err := os.MkdirAll(dir, fs.ModePerm)
+	dbConfig, err := config.GetDatabaseConfig()
 	if err != nil {
 		return err
 	}
 
+	if dbConfig.Connection != "mysql" {
+		// Connection is sqlite
+		// Need to create the directory if it doesn't exist
+
+		dir := path.Dir(dbPath)
+		err = os.MkdirAll(dir, fs.ModePerm)
+		if err != nil {
+			return err
+		}
+	}
+
 	var gormLogger logger.Interface
 
 	if config.IsDebug() {
@@ -131,9 +142,18 @@ func InitDB(dbPath string) error {
 	c := &gorm.Config{
 		Logger: gormLogger,
 	}
-	db, err = gorm.Open(sqlite.Open(dbPath), c)
-	if err != nil {
-		return err
+
+	if dbConfig.Connection == "mysql" {
+		db, err = gorm.Open(mysql.Open(dbPath), c)
+		if err != nil {
+			return err
+		}
+	} else {
+		// Connection is sqlite
+		db, err = gorm.Open(sqlite.Open(dbPath), c)
+		if err != nil {
+			return err
+		}
 	}
 
 	if err := initModels(); err != nil {

+ 1 - 1
database/model/model.go

@@ -86,7 +86,7 @@ func (i *Inbound) GenXrayInboundConfig() *xray.InboundConfig {
 
 type Setting struct {
 	Id    int    `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
-	Key   string `json:"key" form:"key"`
+	Key   string `json:"key" form:"key" gorm:"column:key"`
 	Value string `json:"value" form:"value"`
 }
 

+ 2 - 0
go.mod

@@ -22,6 +22,7 @@ require (
 	golang.org/x/crypto v0.38.0
 	golang.org/x/text v0.25.0
 	google.golang.org/grpc v1.72.1
+	gorm.io/driver/mysql v1.5.7
 	gorm.io/driver/sqlite v1.5.7
 	gorm.io/gorm v1.25.12
 )
@@ -41,6 +42,7 @@ require (
 	github.com/go-playground/locales v0.14.1 // indirect
 	github.com/go-playground/universal-translator v0.18.1 // indirect
 	github.com/go-playground/validator/v10 v10.26.0 // indirect
+	github.com/go-sql-driver/mysql v1.7.0 // indirect
 	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
 	github.com/google/btree v1.1.3 // indirect
 	github.com/google/pprof v0.0.0-20250501235452-c0086092b71a // indirect

+ 5 - 0
go.sum

@@ -51,6 +51,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
 github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
 github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
 github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
+github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
+github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
 github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
 github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
 github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
@@ -261,8 +263,11 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
 gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
+gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
 gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I=
 gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
+gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
 gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
 gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
 gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=

+ 1 - 0
main.go

@@ -105,6 +105,7 @@ func runWebServer() {
 		default:
 			server.Stop()
 			subServer.Stop()
+			database.CloseDB()
 			log.Println("Shutting down servers.")
 			return
 		}

+ 58 - 16
sub/subService.go

@@ -107,18 +107,30 @@ func (s *SubService) GetSubs(subId string, host string) ([]string, string, error
 func (s *SubService) getInboundsBySubId(subId string) ([]*model.Inbound, error) {
 	db := database.GetDB()
 	var inbounds []*model.Inbound
-	err := db.Model(model.Inbound{}).Preload("ClientStats").Where(`id in (
-		SELECT DISTINCT inbounds.id
-		FROM inbounds,
-			JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client 
-		WHERE
-			protocol in ('vmess','vless','trojan','shadowsocks')
-			AND JSON_EXTRACT(client.value, '$.subId') = ? AND enable = ?
-	)`, subId, true).Find(&inbounds).Error
+	err := db.Model(model.Inbound{}).
+		Preload("ClientStats").
+		Where("protocol IN ? AND enable = ?", []string{"vmess", "vless", "trojan", "shadowsocks"}, true).
+		Find(&inbounds).Error
 	if err != nil {
 		return nil, err
 	}
-	return inbounds, nil
+
+	// Filter inbounds that have clients with matching subId
+	var filteredInbounds []*model.Inbound
+	for _, inbound := range inbounds {
+		clients, err := s.inboundService.GetClients(inbound)
+		if err != nil {
+			continue
+		}
+		for _, client := range clients {
+			if client.SubID == subId {
+				filteredInbounds = append(filteredInbounds, inbound)
+				break
+			}
+		}
+	}
+
+	return filteredInbounds, nil
 }
 
 func (s *SubService) getClientTraffics(traffics []xray.ClientTraffic, email string) xray.ClientTraffic {
@@ -132,25 +144,55 @@ func (s *SubService) getClientTraffics(traffics []xray.ClientTraffic, email stri
 
 func (s *SubService) getFallbackMaster(dest string, streamSettings string) (string, int, string, error) {
 	db := database.GetDB()
-	var inbound *model.Inbound
-	err := db.Model(model.Inbound{}).
-		Where("JSON_TYPE(settings, '$.fallbacks') = 'array'").
-		Where("EXISTS (SELECT * FROM json_each(settings, '$.fallbacks') WHERE json_extract(value, '$.dest') = ?)", dest).
-		Find(&inbound).Error
+	var inbounds []*model.Inbound
+	err := db.Model(model.Inbound{}).Find(&inbounds).Error
 	if err != nil {
 		return "", 0, "", err
 	}
 
+	// Find inbound with matching fallback dest
+	var masterInbound *model.Inbound
+	for _, inbound := range inbounds {
+		var settings map[string]any
+		err := json.Unmarshal([]byte(inbound.Settings), &settings)
+		if err != nil {
+			continue
+		}
+
+		fallbacks, ok := settings["fallbacks"].([]any)
+		if !ok {
+			continue
+		}
+
+		for _, fallback := range fallbacks {
+			f, ok := fallback.(map[string]any)
+			if !ok {
+				continue
+			}
+			if fallbackDest, ok := f["dest"].(string); ok && fallbackDest == dest {
+				masterInbound = inbound
+				break
+			}
+		}
+		if masterInbound != nil {
+			break
+		}
+	}
+
+	if masterInbound == nil {
+		return "", 0, "", fmt.Errorf("no inbound found with fallback dest: %s", dest)
+	}
+
 	var stream map[string]any
 	json.Unmarshal([]byte(streamSettings), &stream)
 	var masterStream map[string]any
-	json.Unmarshal([]byte(inbound.StreamSettings), &masterStream)
+	json.Unmarshal([]byte(masterInbound.StreamSettings), &masterStream)
 	stream["security"] = masterStream["security"]
 	stream["tlsSettings"] = masterStream["tlsSettings"]
 	stream["externalProxy"] = masterStream["externalProxy"]
 	modifiedStream, _ := json.MarshalIndent(stream, "", "  ")
 
-	return inbound.Listen, inbound.Port, string(modifiedStream), nil
+	return masterInbound.Listen, masterInbound.Port, string(modifiedStream), nil
 }
 
 func (s *SubService) getLink(inbound *model.Inbound, email string) string {

+ 84 - 24
web/service/inbound.go

@@ -87,15 +87,24 @@ func (s *InboundService) GetClients(inbound *model.Inbound) ([]model.Client, err
 
 func (s *InboundService) getAllEmails() ([]string, error) {
 	db := database.GetDB()
-	var emails []string
-	err := db.Raw(`
-		SELECT JSON_EXTRACT(client.value, '$.email')
-		FROM inbounds,
-			JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client
-		`).Scan(&emails).Error
+	var inbounds []*model.Inbound
+	err := db.Model(model.Inbound{}).Find(&inbounds).Error
 	if err != nil {
 		return nil, err
 	}
+
+	var emails []string
+	for _, inbound := range inbounds {
+		clients, err := s.GetClients(inbound)
+		if err != nil {
+			continue
+		}
+		for _, client := range clients {
+			if client.Email != "" {
+				emails = append(emails, client.Email)
+			}
+		}
+	}
 	return emails, nil
 }
 
@@ -1120,14 +1129,46 @@ func (s *InboundService) GetInboundTags() (string, error) {
 
 func (s *InboundService) MigrationRemoveOrphanedTraffics() {
 	db := database.GetDB()
-	db.Exec(`
-		DELETE FROM client_traffics
-		WHERE email NOT IN (
-			SELECT JSON_EXTRACT(client.value, '$.email')
-			FROM inbounds,
-				JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client
-		)
-	`)
+
+	// Get all inbounds
+	var inbounds []*model.Inbound
+	err := db.Model(model.Inbound{}).Find(&inbounds).Error
+	if err != nil {
+		logger.Error("Failed to get inbounds:", err)
+		return
+	}
+
+	// Collect all valid emails from inbounds
+	validEmails := make(map[string]bool)
+	for _, inbound := range inbounds {
+		clients, err := s.GetClients(inbound)
+		if err != nil {
+			continue
+		}
+		for _, client := range clients {
+			if client.Email != "" {
+				validEmails[client.Email] = true
+			}
+		}
+	}
+
+	// Get all client traffics
+	var traffics []xray.ClientTraffic
+	err = db.Model(xray.ClientTraffic{}).Find(&traffics).Error
+	if err != nil {
+		logger.Error("Failed to get client traffics:", err)
+		return
+	}
+
+	// Delete traffics with emails not in validEmails
+	for _, traffic := range traffics {
+		if !validEmails[traffic.Email] {
+			err = db.Delete(&traffic).Error
+			if err != nil {
+				logger.Error("Failed to delete orphaned traffic:", err)
+			}
+		}
+	}
 }
 
 func (s *InboundService) AddClientStat(tx *gorm.DB, inboundId int, client *model.Client) error {
@@ -1789,19 +1830,38 @@ func (s *InboundService) GetClientTrafficByID(id string) ([]xray.ClientTraffic,
 	db := database.GetDB()
 	var traffics []xray.ClientTraffic
 
-	err := db.Model(xray.ClientTraffic{}).Where(`email IN(
-		SELECT JSON_EXTRACT(client.value, '$.email') as email
-		FROM inbounds,
-	  	JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client
-		WHERE
-	  	JSON_EXTRACT(client.value, '$.id') in (?)
-		)`, id).Find(&traffics).Error
-
+	// First get all inbounds
+	var inbounds []*model.Inbound
+	err := db.Model(model.Inbound{}).Find(&inbounds).Error
 	if err != nil {
-		logger.Debug(err)
 		return nil, err
 	}
-	return traffics, err
+
+	// Collect all emails that match the ID
+	var targetEmails []string
+	for _, inbound := range inbounds {
+		clients, err := s.GetClients(inbound)
+		if err != nil {
+			continue
+		}
+		for _, client := range clients {
+			if client.ID == id && client.Email != "" {
+				targetEmails = append(targetEmails, client.Email)
+			}
+		}
+	}
+	// Get traffics for the collected emails
+	if len(targetEmails) > 0 {
+		err = db.Model(xray.ClientTraffic{}).
+			Where("email IN ?", targetEmails).
+			Find(&traffics).Error
+		if err != nil {
+			logger.Debug(err)
+			return nil, err
+		}
+	}
+
+	return traffics, nil
 }
 
 func (s *InboundService) SearchClientTraffic(query string) (traffic *xray.ClientTraffic, err error) {

+ 2 - 2
web/service/setting.go

@@ -88,7 +88,7 @@ func (s *SettingService) GetDefaultJsonConfig() (any, error) {
 func (s *SettingService) GetAllSetting() (*entity.AllSetting, error) {
 	db := database.GetDB()
 	settings := make([]*model.Setting, 0)
-	err := db.Model(model.Setting{}).Not("key = ?", "xrayTemplateConfig").Find(&settings).Error
+	err := db.Model(model.Setting{}).Not("`key` = ?", "xrayTemplateConfig").Find(&settings).Error
 	if err != nil {
 		return nil, err
 	}
@@ -173,7 +173,7 @@ func (s *SettingService) ResetSettings() error {
 func (s *SettingService) getSetting(key string) (*model.Setting, error) {
 	db := database.GetDB()
 	setting := &model.Setting{}
-	err := db.Model(model.Setting{}).Where("key = ?", key).First(setting).Error
+	err := db.Model(model.Setting{}).Where("`key` = ?", key).First(setting).Error
 	if err != nil {
 		return nil, err
 	}