|
@@ -4,6 +4,7 @@ import (
|
|
|
"bytes"
|
|
|
"io"
|
|
|
"io/fs"
|
|
|
+ "log"
|
|
|
"os"
|
|
|
"path"
|
|
|
|
|
@@ -18,54 +19,51 @@ import (
|
|
|
|
|
|
var db *gorm.DB
|
|
|
|
|
|
-var initializers = []func() error{
|
|
|
- initUser,
|
|
|
- initInbound,
|
|
|
- initOutbound,
|
|
|
- initSetting,
|
|
|
- initInboundClientIps,
|
|
|
- initClientTraffic,
|
|
|
+const (
|
|
|
+ defaultUsername = "admin"
|
|
|
+ defaultPassword = "admin"
|
|
|
+ defaultSecret = ""
|
|
|
+)
|
|
|
+
|
|
|
+func initModels() error {
|
|
|
+ models := []interface{}{
|
|
|
+ &model.User{},
|
|
|
+ &model.Inbound{},
|
|
|
+ &model.OutboundTraffics{},
|
|
|
+ &model.Setting{},
|
|
|
+ &model.InboundClientIps{},
|
|
|
+ &xray.ClientTraffic{},
|
|
|
+ }
|
|
|
+ for _, model := range models {
|
|
|
+ if err := db.AutoMigrate(model); err != nil {
|
|
|
+ log.Printf("Error auto migrating model: %v", err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func initUser() error {
|
|
|
- err := db.AutoMigrate(&model.User{})
|
|
|
+ empty, err := isTableEmpty("users")
|
|
|
if err != nil {
|
|
|
+ log.Printf("Error checking if users table is empty: %v", err)
|
|
|
return err
|
|
|
}
|
|
|
- var count int64
|
|
|
- err = db.Model(&model.User{}).Count(&count).Error
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- if count == 0 {
|
|
|
+ if empty {
|
|
|
user := &model.User{
|
|
|
- Username: "admin",
|
|
|
- Password: "admin",
|
|
|
- LoginSecret: "",
|
|
|
+ Username: defaultUsername,
|
|
|
+ Password: defaultPassword,
|
|
|
+ LoginSecret: defaultSecret,
|
|
|
}
|
|
|
return db.Create(user).Error
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func initInbound() error {
|
|
|
- return db.AutoMigrate(&model.Inbound{})
|
|
|
-}
|
|
|
-
|
|
|
-func initOutbound() error {
|
|
|
- return db.AutoMigrate(&model.OutboundTraffics{})
|
|
|
-}
|
|
|
-
|
|
|
-func initSetting() error {
|
|
|
- return db.AutoMigrate(&model.Setting{})
|
|
|
-}
|
|
|
-
|
|
|
-func initInboundClientIps() error {
|
|
|
- return db.AutoMigrate(&model.InboundClientIps{})
|
|
|
-}
|
|
|
-
|
|
|
-func initClientTraffic() error {
|
|
|
- return db.AutoMigrate(&xray.ClientTraffic{})
|
|
|
+func isTableEmpty(tableName string) (bool, error) {
|
|
|
+ var count int64
|
|
|
+ err := db.Table(tableName).Count(&count).Error
|
|
|
+ return count == 0, err
|
|
|
}
|
|
|
|
|
|
func InitDB(dbPath string) error {
|
|
@@ -91,12 +89,24 @@ func InitDB(dbPath string) error {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- for _, initialize := range initializers {
|
|
|
- if err := initialize(); err != nil {
|
|
|
+ if err := initModels(); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if err := initUser(); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func CloseDB() error {
|
|
|
+ if db != nil {
|
|
|
+ sqlDB, err := db.DB()
|
|
|
+ if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+ return sqlDB.Close()
|
|
|
}
|
|
|
-
|
|
|
return nil
|
|
|
}
|
|
|
|