1
0
Эх сурвалжийг харах

feat: Add MySQL database support (#3024)

* feat: Add MySQL database support

- Add MySQL database support with environment-based configuration
- Fix MySQL compatibility issue with 'key' column name
- Maintain SQLite as default database
- Add proper validation for MySQL configuration
- Test and verify compatibility with existing database
- Replaced raw SQL queries using JSON_EACH functions with standard GORM queries
- Modified functions to handle JSON parsing in Go code instead of database since JSON_EACH is not available on MySQL or MariaDB:
  - getAllEmails()
  - GetClientTrafficByID()
  - getFallbackMaster()
  - MigrationRemoveOrphanedTraffics()

The system now supports both MySQL and SQLite databases, with SQLite remaining as the default option. MySQL connection is only used when explicitly configured through environment variables.

* refactor: prefix env variables of database with XUI_ to support direct environment usage without .env file

All database configuration environment variables now start with the XUI_ prefix to avoid conflicts and allow configuration via system-level environment variables, not just the .env file.
Ali Golzar 2 долоо хоног өмнө
parent
commit
3850e2f070

+ 8 - 0
.env.example

@@ -0,0 +1,8 @@
+XUI_DB_CONNECTION=sqlite
+
+# If DB connection is "mysql"
+# XUI_DB_HOST=127.0.0.1
+# XUI_DB_PORT=3306
+# XUI_DB_DATABASE=xui
+# XUI_DB_USERNAME=root
+# XUI_DB_PASSWORD=

+ 49 - 0
config/config.go

@@ -3,6 +3,7 @@ package config
 import (
 	_ "embed"
 	"fmt"
+	"log"
 	"os"
 	"strings"
 )
@@ -62,7 +63,55 @@ func GetDBFolderPath() string {
 	return dbFolderPath
 }
 
+// DatabaseConfig holds the database configuration
+type DatabaseConfig struct {
+	Connection string
+	Host       string
+	Port       string
+	Database   string
+	Username   string
+	Password   string
+}
+
+// GetDatabaseConfig returns the database configuration from environment variables
+func GetDatabaseConfig() (*DatabaseConfig, error) {
+	config := &DatabaseConfig{
+		Connection: strings.ToLower(os.Getenv("XUI_DB_CONNECTION")),
+		Host:       os.Getenv("XUI_DB_HOST"),
+		Port:       os.Getenv("XUI_DB_PORT"),
+		Database:   os.Getenv("XUI_DB_DATABASE"),
+		Username:   os.Getenv("XUI_DB_USERNAME"),
+		Password:   os.Getenv("XUI_DB_PASSWORD"),
+	}
+
+	if config.Connection == "mysql" {
+		if config.Host == "" || config.Database == "" || config.Username == "" {
+			return nil, fmt.Errorf("missing required MySQL configuration: host, database, and username are required")
+		}
+		if config.Port == "" {
+			config.Port = "3306"
+		}
+	}
+
+	return config, nil
+}
+
 func GetDBPath() string {
+	config, err := GetDatabaseConfig()
+	if err != nil {
+		log.Fatalf("Error getting database config: %v", err)
+	}
+
+	if config.Connection == "mysql" {
+		return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local",
+			config.Username,
+			config.Password,
+			config.Host,
+			config.Port,
+			config.Database)
+	}
+
+	// Connection is sqlite
 	return fmt.Sprintf("%s/%s.db", GetDBFolderPath(), GetName())
 }
 

+ 29 - 9
database/db.go

@@ -9,14 +9,15 @@ import (
 	"path"
 	"slices"
 
+	"gorm.io/driver/mysql"
+	"gorm.io/driver/sqlite"
+	"gorm.io/gorm"
+	"gorm.io/gorm/logger"
+
 	"x-ui/config"
 	"x-ui/database/model"
 	"x-ui/util/crypto"
 	"x-ui/xray"
-
-	"gorm.io/driver/sqlite"
-	"gorm.io/gorm"
-	"gorm.io/gorm/logger"
 )
 
 var db *gorm.DB
@@ -114,12 +115,22 @@ func isTableEmpty(tableName string) (bool, error) {
 }
 
 func InitDB(dbPath string) error {
-	dir := path.Dir(dbPath)
-	err := os.MkdirAll(dir, fs.ModePerm)
+	dbConfig, err := config.GetDatabaseConfig()
 	if err != nil {
 		return err
 	}
 
+	if dbConfig.Connection != "mysql" {
+		// Connection is sqlite
+		// Need to create the directory if it doesn't exist
+
+		dir := path.Dir(dbPath)
+		err = os.MkdirAll(dir, fs.ModePerm)
+		if err != nil {
+			return err
+		}
+	}
+
 	var gormLogger logger.Interface
 
 	if config.IsDebug() {
@@ -131,9 +142,18 @@ func InitDB(dbPath string) error {
 	c := &gorm.Config{
 		Logger: gormLogger,
 	}
-	db, err = gorm.Open(sqlite.Open(dbPath), c)
-	if err != nil {
-		return err
+
+	if dbConfig.Connection == "mysql" {
+		db, err = gorm.Open(mysql.Open(dbPath), c)
+		if err != nil {
+			return err
+		}
+	} else {
+		// Connection is sqlite
+		db, err = gorm.Open(sqlite.Open(dbPath), c)
+		if err != nil {
+			return err
+		}
 	}
 
 	if err := initModels(); err != nil {

+ 1 - 1
database/model/model.go

@@ -86,7 +86,7 @@ func (i *Inbound) GenXrayInboundConfig() *xray.InboundConfig {
 
 type Setting struct {
 	Id    int    `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
-	Key   string `json:"key" form:"key"`
+	Key   string `json:"key" form:"key" gorm:"column:key"`
 	Value string `json:"value" form:"value"`
 }
 

+ 2 - 0
go.mod

@@ -22,6 +22,7 @@ require (
 	golang.org/x/crypto v0.38.0
 	golang.org/x/text v0.25.0
 	google.golang.org/grpc v1.72.1
+	gorm.io/driver/mysql v1.5.7
 	gorm.io/driver/sqlite v1.5.7
 	gorm.io/gorm v1.25.12
 )
@@ -41,6 +42,7 @@ require (
 	github.com/go-playground/locales v0.14.1 // indirect
 	github.com/go-playground/universal-translator v0.18.1 // indirect
 	github.com/go-playground/validator/v10 v10.26.0 // indirect
+	github.com/go-sql-driver/mysql v1.7.0 // indirect
 	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
 	github.com/google/btree v1.1.3 // indirect
 	github.com/google/pprof v0.0.0-20250501235452-c0086092b71a // indirect

+ 5 - 0
go.sum

@@ -51,6 +51,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
 github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
 github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
 github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
+github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
+github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
 github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
 github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
 github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
@@ -261,8 +263,11 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
 gopkg.in/yaml.v3 v3.0.0-20200605160147-a5ece683394c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
+gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
 gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I=
 gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
+gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
 gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
 gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
 gvisor.dev/gvisor v0.0.0-20250428193742-2d800c3129d5 h1:sfK5nHuG7lRFZ2FdTT3RimOqWBg8IrVm+/Vko1FVOsk=

+ 1 - 0
main.go

@@ -105,6 +105,7 @@ func runWebServer() {
 		default:
 			server.Stop()
 			subServer.Stop()
+			database.CloseDB()
 			log.Println("Shutting down servers.")
 			return
 		}

+ 58 - 16
sub/subService.go

@@ -107,18 +107,30 @@ func (s *SubService) GetSubs(subId string, host string) ([]string, string, error
 func (s *SubService) getInboundsBySubId(subId string) ([]*model.Inbound, error) {
 	db := database.GetDB()
 	var inbounds []*model.Inbound
-	err := db.Model(model.Inbound{}).Preload("ClientStats").Where(`id in (
-		SELECT DISTINCT inbounds.id
-		FROM inbounds,
-			JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client 
-		WHERE
-			protocol in ('vmess','vless','trojan','shadowsocks')
-			AND JSON_EXTRACT(client.value, '$.subId') = ? AND enable = ?
-	)`, subId, true).Find(&inbounds).Error
+	err := db.Model(model.Inbound{}).
+		Preload("ClientStats").
+		Where("protocol IN ? AND enable = ?", []string{"vmess", "vless", "trojan", "shadowsocks"}, true).
+		Find(&inbounds).Error
 	if err != nil {
 		return nil, err
 	}
-	return inbounds, nil
+
+	// Filter inbounds that have clients with matching subId
+	var filteredInbounds []*model.Inbound
+	for _, inbound := range inbounds {
+		clients, err := s.inboundService.GetClients(inbound)
+		if err != nil {
+			continue
+		}
+		for _, client := range clients {
+			if client.SubID == subId {
+				filteredInbounds = append(filteredInbounds, inbound)
+				break
+			}
+		}
+	}
+
+	return filteredInbounds, nil
 }
 
 func (s *SubService) getClientTraffics(traffics []xray.ClientTraffic, email string) xray.ClientTraffic {
@@ -132,25 +144,55 @@ func (s *SubService) getClientTraffics(traffics []xray.ClientTraffic, email stri
 
 func (s *SubService) getFallbackMaster(dest string, streamSettings string) (string, int, string, error) {
 	db := database.GetDB()
-	var inbound *model.Inbound
-	err := db.Model(model.Inbound{}).
-		Where("JSON_TYPE(settings, '$.fallbacks') = 'array'").
-		Where("EXISTS (SELECT * FROM json_each(settings, '$.fallbacks') WHERE json_extract(value, '$.dest') = ?)", dest).
-		Find(&inbound).Error
+	var inbounds []*model.Inbound
+	err := db.Model(model.Inbound{}).Find(&inbounds).Error
 	if err != nil {
 		return "", 0, "", err
 	}
 
+	// Find inbound with matching fallback dest
+	var masterInbound *model.Inbound
+	for _, inbound := range inbounds {
+		var settings map[string]any
+		err := json.Unmarshal([]byte(inbound.Settings), &settings)
+		if err != nil {
+			continue
+		}
+
+		fallbacks, ok := settings["fallbacks"].([]any)
+		if !ok {
+			continue
+		}
+
+		for _, fallback := range fallbacks {
+			f, ok := fallback.(map[string]any)
+			if !ok {
+				continue
+			}
+			if fallbackDest, ok := f["dest"].(string); ok && fallbackDest == dest {
+				masterInbound = inbound
+				break
+			}
+		}
+		if masterInbound != nil {
+			break
+		}
+	}
+
+	if masterInbound == nil {
+		return "", 0, "", fmt.Errorf("no inbound found with fallback dest: %s", dest)
+	}
+
 	var stream map[string]any
 	json.Unmarshal([]byte(streamSettings), &stream)
 	var masterStream map[string]any
-	json.Unmarshal([]byte(inbound.StreamSettings), &masterStream)
+	json.Unmarshal([]byte(masterInbound.StreamSettings), &masterStream)
 	stream["security"] = masterStream["security"]
 	stream["tlsSettings"] = masterStream["tlsSettings"]
 	stream["externalProxy"] = masterStream["externalProxy"]
 	modifiedStream, _ := json.MarshalIndent(stream, "", "  ")
 
-	return inbound.Listen, inbound.Port, string(modifiedStream), nil
+	return masterInbound.Listen, masterInbound.Port, string(modifiedStream), nil
 }
 
 func (s *SubService) getLink(inbound *model.Inbound, email string) string {

+ 84 - 24
web/service/inbound.go

@@ -87,15 +87,24 @@ func (s *InboundService) GetClients(inbound *model.Inbound) ([]model.Client, err
 
 func (s *InboundService) getAllEmails() ([]string, error) {
 	db := database.GetDB()
-	var emails []string
-	err := db.Raw(`
-		SELECT JSON_EXTRACT(client.value, '$.email')
-		FROM inbounds,
-			JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client
-		`).Scan(&emails).Error
+	var inbounds []*model.Inbound
+	err := db.Model(model.Inbound{}).Find(&inbounds).Error
 	if err != nil {
 		return nil, err
 	}
+
+	var emails []string
+	for _, inbound := range inbounds {
+		clients, err := s.GetClients(inbound)
+		if err != nil {
+			continue
+		}
+		for _, client := range clients {
+			if client.Email != "" {
+				emails = append(emails, client.Email)
+			}
+		}
+	}
 	return emails, nil
 }
 
@@ -1120,14 +1129,46 @@ func (s *InboundService) GetInboundTags() (string, error) {
 
 func (s *InboundService) MigrationRemoveOrphanedTraffics() {
 	db := database.GetDB()
-	db.Exec(`
-		DELETE FROM client_traffics
-		WHERE email NOT IN (
-			SELECT JSON_EXTRACT(client.value, '$.email')
-			FROM inbounds,
-				JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client
-		)
-	`)
+
+	// Get all inbounds
+	var inbounds []*model.Inbound
+	err := db.Model(model.Inbound{}).Find(&inbounds).Error
+	if err != nil {
+		logger.Error("Failed to get inbounds:", err)
+		return
+	}
+
+	// Collect all valid emails from inbounds
+	validEmails := make(map[string]bool)
+	for _, inbound := range inbounds {
+		clients, err := s.GetClients(inbound)
+		if err != nil {
+			continue
+		}
+		for _, client := range clients {
+			if client.Email != "" {
+				validEmails[client.Email] = true
+			}
+		}
+	}
+
+	// Get all client traffics
+	var traffics []xray.ClientTraffic
+	err = db.Model(xray.ClientTraffic{}).Find(&traffics).Error
+	if err != nil {
+		logger.Error("Failed to get client traffics:", err)
+		return
+	}
+
+	// Delete traffics with emails not in validEmails
+	for _, traffic := range traffics {
+		if !validEmails[traffic.Email] {
+			err = db.Delete(&traffic).Error
+			if err != nil {
+				logger.Error("Failed to delete orphaned traffic:", err)
+			}
+		}
+	}
 }
 
 func (s *InboundService) AddClientStat(tx *gorm.DB, inboundId int, client *model.Client) error {
@@ -1789,19 +1830,38 @@ func (s *InboundService) GetClientTrafficByID(id string) ([]xray.ClientTraffic,
 	db := database.GetDB()
 	var traffics []xray.ClientTraffic
 
-	err := db.Model(xray.ClientTraffic{}).Where(`email IN(
-		SELECT JSON_EXTRACT(client.value, '$.email') as email
-		FROM inbounds,
-	  	JSON_EACH(JSON_EXTRACT(inbounds.settings, '$.clients')) AS client
-		WHERE
-	  	JSON_EXTRACT(client.value, '$.id') in (?)
-		)`, id).Find(&traffics).Error
-
+	// First get all inbounds
+	var inbounds []*model.Inbound
+	err := db.Model(model.Inbound{}).Find(&inbounds).Error
 	if err != nil {
-		logger.Debug(err)
 		return nil, err
 	}
-	return traffics, err
+
+	// Collect all emails that match the ID
+	var targetEmails []string
+	for _, inbound := range inbounds {
+		clients, err := s.GetClients(inbound)
+		if err != nil {
+			continue
+		}
+		for _, client := range clients {
+			if client.ID == id && client.Email != "" {
+				targetEmails = append(targetEmails, client.Email)
+			}
+		}
+	}
+	// Get traffics for the collected emails
+	if len(targetEmails) > 0 {
+		err = db.Model(xray.ClientTraffic{}).
+			Where("email IN ?", targetEmails).
+			Find(&traffics).Error
+		if err != nil {
+			logger.Debug(err)
+			return nil, err
+		}
+	}
+
+	return traffics, nil
 }
 
 func (s *InboundService) SearchClientTraffic(query string) (traffic *xray.ClientTraffic, err error) {

+ 2 - 2
web/service/setting.go

@@ -88,7 +88,7 @@ func (s *SettingService) GetDefaultJsonConfig() (any, error) {
 func (s *SettingService) GetAllSetting() (*entity.AllSetting, error) {
 	db := database.GetDB()
 	settings := make([]*model.Setting, 0)
-	err := db.Model(model.Setting{}).Not("key = ?", "xrayTemplateConfig").Find(&settings).Error
+	err := db.Model(model.Setting{}).Not("`key` = ?", "xrayTemplateConfig").Find(&settings).Error
 	if err != nil {
 		return nil, err
 	}
@@ -173,7 +173,7 @@ func (s *SettingService) ResetSettings() error {
 func (s *SettingService) getSetting(key string) (*model.Setting, error) {
 	db := database.GetDB()
 	setting := &model.Setting{}
-	err := db.Model(model.Setting{}).Where("key = ?", key).First(setting).Error
+	err := db.Model(model.Setting{}).Where("`key` = ?", key).First(setting).Error
 	if err != nil {
 		return nil, err
 	}