1
0
Эх сурвалжийг харах

Refactor database initialization

mhsanaei 4 сар өмнө
parent
commit
dfe0bbd371
1 өөрчлөгдсөн 48 нэмэгдсэн , 38 устгасан
  1. 48 38
      database/db.go

+ 48 - 38
database/db.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"io"
 	"io/fs"
+	"log"
 	"os"
 	"path"
 
@@ -18,54 +19,51 @@ import (
 
 var db *gorm.DB
 
-var initializers = []func() error{
-	initUser,
-	initInbound,
-	initOutbound,
-	initSetting,
-	initInboundClientIps,
-	initClientTraffic,
+const (
+	defaultUsername = "admin"
+	defaultPassword = "admin"
+	defaultSecret   = ""
+)
+
+func initModels() error {
+	models := []interface{}{
+		&model.User{},
+		&model.Inbound{},
+		&model.OutboundTraffics{},
+		&model.Setting{},
+		&model.InboundClientIps{},
+		&xray.ClientTraffic{},
+	}
+	for _, model := range models {
+		if err := db.AutoMigrate(model); err != nil {
+			log.Printf("Error auto migrating model: %v", err)
+			return err
+		}
+	}
+	return nil
 }
 
 func initUser() error {
-	err := db.AutoMigrate(&model.User{})
+	empty, err := isTableEmpty("users")
 	if err != nil {
+		log.Printf("Error checking if users table is empty: %v", err)
 		return err
 	}
-	var count int64
-	err = db.Model(&model.User{}).Count(&count).Error
-	if err != nil {
-		return err
-	}
-	if count == 0 {
+	if empty {
 		user := &model.User{
-			Username:    "admin",
-			Password:    "admin",
-			LoginSecret: "",
+			Username:    defaultUsername,
+			Password:    defaultPassword,
+			LoginSecret: defaultSecret,
 		}
 		return db.Create(user).Error
 	}
 	return nil
 }
 
-func initInbound() error {
-	return db.AutoMigrate(&model.Inbound{})
-}
-
-func initOutbound() error {
-	return db.AutoMigrate(&model.OutboundTraffics{})
-}
-
-func initSetting() error {
-	return db.AutoMigrate(&model.Setting{})
-}
-
-func initInboundClientIps() error {
-	return db.AutoMigrate(&model.InboundClientIps{})
-}
-
-func initClientTraffic() error {
-	return db.AutoMigrate(&xray.ClientTraffic{})
+func isTableEmpty(tableName string) (bool, error) {
+	var count int64
+	err := db.Table(tableName).Count(&count).Error
+	return count == 0, err
 }
 
 func InitDB(dbPath string) error {
@@ -91,12 +89,24 @@ func InitDB(dbPath string) error {
 		return err
 	}
 
-	for _, initialize := range initializers {
-		if err := initialize(); err != nil {
+	if err := initModels(); err != nil {
+		return err
+	}
+	if err := initUser(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func CloseDB() error {
+	if db != nil {
+		sqlDB, err := db.DB()
+		if err != nil {
 			return err
 		}
+		return sqlDB.Close()
 	}
-
 	return nil
 }