1
0

db.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package database
import (
	"bytes"
	"errors"
	"io"
	"io/fs"
	"log"
	"os"
	"path"

	"x-ui/config"
	"x-ui/database/model"
	"x-ui/xray"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
  16. var db *gorm.DB
  17. const (
  18. defaultUsername = "admin"
  19. defaultPassword = "admin"
  20. defaultSecret = ""
  21. )
  22. func initModels() error {
  23. // Order matters: first create tables without dependencies
  24. baseModels := []interface{}{
  25. &model.User{},
  26. &model.Setting{},
  27. }
  28. // Migrate base models
  29. for _, model := range baseModels {
  30. if err := db.AutoMigrate(model); err != nil {
  31. log.Printf("Error auto migrating base model: %v", err)
  32. return err
  33. }
  34. }
  35. // Then migrate models with dependencies
  36. dependentModels := []interface{}{
  37. &model.Inbound{},
  38. &model.OutboundTraffics{},
  39. &model.InboundClientIps{},
  40. &xray.ClientTraffic{},
  41. }
  42. for _, model := range dependentModels {
  43. if err := db.AutoMigrate(model); err != nil {
  44. log.Printf("Error auto migrating dependent model: %v", err)
  45. return err
  46. }
  47. }
  48. return nil
  49. }
  50. func initUser() error {
  51. empty, err := isTableEmpty("users")
  52. if err != nil {
  53. log.Printf("Error checking if users table is empty: %v", err)
  54. return err
  55. }
  56. if empty {
  57. user := &model.User{
  58. Username: defaultUsername,
  59. Password: defaultPassword,
  60. LoginSecret: defaultSecret,
  61. }
  62. return db.Create(user).Error
  63. }
  64. return nil
  65. }
  66. func isTableEmpty(tableName string) (bool, error) {
  67. var count int64
  68. err := db.Table(tableName).Count(&count).Error
  69. return count == 0, err
  70. }
  71. func InitDB(dbPath string) error {
  72. dir := path.Dir(dbPath)
  73. err := os.MkdirAll(dir, fs.ModePerm)
  74. if err != nil {
  75. return err
  76. }
  77. var gormLogger logger.Interface
  78. if config.IsDebug() {
  79. gormLogger = logger.Default
  80. } else {
  81. gormLogger = logger.Discard
  82. }
  83. c := &gorm.Config{
  84. Logger: gormLogger,
  85. SkipDefaultTransaction: true,
  86. PrepareStmt: true,
  87. }
  88. dsn := dbPath + "?cache=shared&_journal_mode=WAL&_synchronous=NORMAL"
  89. db, err = gorm.Open(sqlite.Open(dsn), c)
  90. if err != nil {
  91. return err
  92. }
  93. sqlDB, err := db.DB()
  94. if err != nil {
  95. return err
  96. }
  97. _, err = sqlDB.Exec("PRAGMA cache_size = -64000;")
  98. if err != nil {
  99. return err
  100. }
  101. _, err = sqlDB.Exec("PRAGMA temp_store = MEMORY;")
  102. if err != nil {
  103. return err
  104. }
  105. _, err = sqlDB.Exec("PRAGMA foreign_keys = ON;")
  106. if err != nil {
  107. return err
  108. }
  109. if err := initModels(); err != nil {
  110. return err
  111. }
  112. if err := initUser(); err != nil {
  113. return err
  114. }
  115. return nil
  116. }
  117. func CloseDB() error {
  118. if db != nil {
  119. if err := Checkpoint(); err != nil {
  120. log.Printf("error executing checkpoint: %v", err)
  121. }
  122. sqlDB, err := db.DB()
  123. if err != nil {
  124. return err
  125. }
  126. return sqlDB.Close()
  127. }
  128. return nil
  129. }
  130. func GetDB() *gorm.DB {
  131. return db
  132. }
  133. func IsNotFound(err error) bool {
  134. return err == gorm.ErrRecordNotFound
  135. }
  136. func IsSQLiteDB(file io.ReaderAt) (bool, error) {
  137. signature := []byte("SQLite format 3\x00")
  138. buf := make([]byte, len(signature))
  139. _, err := file.ReadAt(buf, 0)
  140. if err != nil {
  141. return false, err
  142. }
  143. return bytes.Equal(buf, signature), nil
  144. }
  145. func Checkpoint() error {
  146. // Update WAL
  147. err := db.Exec("PRAGMA wal_checkpoint;").Error
  148. if err != nil {
  149. return err
  150. }
  151. return nil
  152. }