1
0

migrate_data.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package database
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "os"
  8. "path"
  9. "reflect"
  10. "strings"
  11. "time"
  12. "github.com/mhsanaei/3x-ui/v3/database/model"
  13. "github.com/mhsanaei/3x-ui/v3/xray"
  14. "gorm.io/driver/postgres"
  15. "gorm.io/driver/sqlite"
  16. "gorm.io/gorm"
  17. "gorm.io/gorm/logger"
  18. )
  19. // migrationModels is the FK-aware order in which tables are created and copied.
  20. // Parents come before their children so foreign-key constraints stay satisfied
  21. // even when checks are not explicitly disabled.
  22. func migrationModels() []any {
  23. return []any{
  24. &model.User{},
  25. &model.Setting{},
  26. &model.HistoryOfSeeders{},
  27. &model.CustomGeoResource{},
  28. &model.Node{},
  29. &model.ApiToken{},
  30. &model.Inbound{},
  31. &xray.ClientTraffic{},
  32. &model.OutboundTraffics{},
  33. &model.InboundClientIps{},
  34. &model.ClientRecord{},
  35. &model.ClientInbound{},
  36. &model.InboundFallback{},
  37. &model.NodeClientTraffic{},
  38. }
  39. }
  40. // MigrateData copies every row from the configured SQLite file at srcPath into
  41. // a fresh PostgreSQL database described by dstDSN. The destination tables are
  42. // (re)created with AutoMigrate before the copy. Source data is left untouched.
  43. func MigrateData(srcPath, dstDSN string) error {
  44. if _, err := os.Stat(srcPath); err != nil {
  45. return fmt.Errorf("source sqlite not found at %s: %w", srcPath, err)
  46. }
  47. if dstDSN == "" {
  48. return errors.New("destination DSN is required")
  49. }
  50. if err := os.MkdirAll(path.Dir(srcPath), 0755); err != nil {
  51. return err
  52. }
  53. srcDSN := srcPath + "?_journal_mode=WAL&_busy_timeout=10000"
  54. src, err := gorm.Open(sqlite.Open(srcDSN), &gorm.Config{Logger: logger.Discard})
  55. if err != nil {
  56. return fmt.Errorf("open sqlite source: %w", err)
  57. }
  58. srcSQL, err := src.DB()
  59. if err != nil {
  60. return err
  61. }
  62. defer srcSQL.Close()
  63. dst, err := gorm.Open(postgres.Open(dstDSN), &gorm.Config{Logger: logger.Discard})
  64. if err != nil {
  65. return fmt.Errorf("open postgres destination: %w", err)
  66. }
  67. dstSQL, err := dst.DB()
  68. if err != nil {
  69. return err
  70. }
  71. defer dstSQL.Close()
  72. dstSQL.SetConnMaxLifetime(time.Hour)
  73. log.Println("Creating destination schema...")
  74. for _, m := range migrationModels() {
  75. if err := dst.AutoMigrate(m); err != nil {
  76. return fmt.Errorf("AutoMigrate %T: %w", m, err)
  77. }
  78. }
  79. totalRows := 0
  80. for _, m := range migrationModels() {
  81. n, err := copyTable(src, dst, m)
  82. if err != nil {
  83. return fmt.Errorf("copy %T: %w", m, err)
  84. }
  85. totalRows += n
  86. log.Printf(" %-32s %d rows", reflect.TypeOf(m).Elem().Name(), n)
  87. }
  88. if err := resetPostgresSequences(dst); err != nil {
  89. log.Printf("warning: failed to reset some postgres sequences: %v", err)
  90. }
  91. log.Printf("Migration complete: %d rows across %d tables.", totalRows, len(migrationModels()))
  92. log.Println("Set XUI_DB_TYPE=postgres and XUI_DB_DSN=... in /etc/default/x-ui, then restart x-ui.")
  93. return nil
  94. }
  95. func copyTable(src, dst *gorm.DB, mdl any) (int, error) {
  96. const batchSize = 500
  97. sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(mdl).Elem()))
  98. stmt := &gorm.Statement{DB: src}
  99. if err := stmt.Parse(mdl); err != nil {
  100. return 0, err
  101. }
  102. order := strings.Join(stmt.Schema.PrimaryFieldDBNames, ", ")
  103. table := stmt.Schema.Table
  104. columns := stmt.Schema.DBNames
  105. ctx := context.Background()
  106. total := 0
  107. for offset := 0; ; offset += batchSize {
  108. batchPtr := reflect.New(sliceType)
  109. q := src.Model(mdl).Limit(batchSize).Offset(offset)
  110. if order != "" {
  111. q = q.Order(order)
  112. }
  113. if err := q.Find(batchPtr.Interface()).Error; err != nil {
  114. return total, err
  115. }
  116. slice := batchPtr.Elem()
  117. n := slice.Len()
  118. if n == 0 {
  119. break
  120. }
  121. rows := make([]map[string]any, n)
  122. for i := 0; i < n; i++ {
  123. rv := reflect.Indirect(slice.Index(i))
  124. row := make(map[string]any, len(columns))
  125. for _, name := range columns {
  126. value, _ := stmt.Schema.FieldsByDBName[name].ValueOf(ctx, rv)
  127. row[name] = value
  128. }
  129. rows[i] = row
  130. }
  131. if err := dst.Table(table).CreateInBatches(rows, 200).Error; err != nil {
  132. return total, err
  133. }
  134. total += n
  135. if n < batchSize {
  136. break
  137. }
  138. }
  139. return total, nil
  140. }
  141. // resetPostgresSequences advances each migrated table's id sequence past MAX(id),
  142. // otherwise the next INSERT-without-id would clash with copied rows.
  143. func resetPostgresSequences(dst *gorm.DB) error {
  144. return resyncPostgresSequences(dst, migrationModels())
  145. }
  146. // resyncPostgresSequences sets each model's id sequence to MAX(id) so the next
  147. // auto-increment INSERT won't collide with an existing row. Table names are
  148. // resolved from the models themselves (not hardcoded), so they always match the
  149. // migrated tables. The statement is a no-op for tables without an id sequence
  150. // (e.g. composite-PK tables), and idempotent on a healthy DB, so it is safe to
  151. // run both after migration and on every Postgres startup.
  152. func resyncPostgresSequences(db *gorm.DB, models []any) error {
  153. for _, m := range models {
  154. stmt := &gorm.Statement{DB: db}
  155. if err := stmt.Parse(m); err != nil {
  156. continue
  157. }
  158. t := stmt.Table
  159. // t comes from the trusted model set parsed by GORM, not user input, so
  160. // interpolating it as an identifier is safe. We ignore errors per-table.
  161. _ = db.Exec(
  162. `SELECT setval(pg_get_serial_sequence(?, 'id'), COALESCE((SELECT MAX(id) FROM "`+t+`"), 1), true)
  163. WHERE pg_get_serial_sequence(?, 'id') IS NOT NULL`,
  164. t, t,
  165. ).Error
  166. }
  167. return nil
  168. }