Prechádzať zdrojové kódy

fix(migrate-db): preserve false-valued columns in SQLite to Postgres copy

GORM struct INSERT substitutes a column default tag for Go zero-values, so disabled rows (enable=false) silently re-enabled on the destination. Copy each batch through explicit per-column maps so every value is written verbatim. Adds a regression test.
MHSanaei 11 hodín pred
rodič
commit
71cf22fa8d
2 zmenil súbory, kde vykonal 94 pridanie a 4 odobranie
  1. 19 4
      database/migrate_data.go
  2. 75 0
      database/migrate_data_test.go

+ 19 - 4
database/migrate_data.go

@@ -1,6 +1,7 @@
 package database
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"log"
@@ -109,14 +110,15 @@ func copyTable(src, dst *gorm.DB, mdl any) (int, error) {
 
 	sliceType := reflect.SliceOf(reflect.PointerTo(reflect.TypeOf(mdl).Elem()))
 
-	// Resolve primary-key columns so paging is deterministic across successive
-	// LIMIT/OFFSET reads. The model set is trusted (not user input).
 	stmt := &gorm.Statement{DB: src}
 	if err := stmt.Parse(mdl); err != nil {
 		return 0, err
 	}
 	order := strings.Join(stmt.Schema.PrimaryFieldDBNames, ", ")
+	table := stmt.Schema.Table
+	columns := stmt.Schema.DBNames
 
+	ctx := context.Background()
 	total := 0
 	for offset := 0; ; offset += batchSize {
 		batchPtr := reflect.New(sliceType)
@@ -127,11 +129,24 @@ func copyTable(src, dst *gorm.DB, mdl any) (int, error) {
 		if err := q.Find(batchPtr.Interface()).Error; err != nil {
 			return total, err
 		}
-		n := batchPtr.Elem().Len()
+		slice := batchPtr.Elem()
+		n := slice.Len()
 		if n == 0 {
 			break
 		}
-		if err := dst.CreateInBatches(batchPtr.Interface(), 200).Error; err != nil {
+
+		rows := make([]map[string]any, n)
+		for i := 0; i < n; i++ {
+			rv := reflect.Indirect(slice.Index(i))
+			row := make(map[string]any, len(columns))
+			for _, name := range columns {
+				value, _ := stmt.Schema.FieldsByDBName[name].ValueOf(ctx, rv)
+				row[name] = value
+			}
+			rows[i] = row
+		}
+
+		if err := dst.Table(table).CreateInBatches(rows, 200).Error; err != nil {
 			return total, err
 		}
 		total += n

+ 75 - 0
database/migrate_data_test.go

@@ -62,3 +62,78 @@ func TestMigrateData_CompositeKeyTableLargerThanBatch(t *testing.T) {
 		t.Fatalf("client_inbounds rows = %d, want %d", got, n)
 	}
 }
+
+func TestMigrateData_PreservesFalseDefaultedColumns(t *testing.T) {
+	dsn := os.Getenv("XUI_TEST_PG_DSN")
+	if dsn == "" {
+		t.Skip("set XUI_TEST_PG_DSN to a reachable Postgres to run this test")
+	}
+
+	srcPath := t.TempDir() + "/x-ui.db"
+	src, err := gorm.Open(sqlite.Open(srcPath), &gorm.Config{Logger: logger.Discard})
+	if err != nil {
+		t.Fatalf("open sqlite: %v", err)
+	}
+	for _, m := range migrationModels() {
+		if err := src.AutoMigrate(m); err != nil {
+			t.Fatalf("automigrate %T: %v", m, err)
+		}
+	}
+
+	if err := src.Create([]*model.ClientRecord{
+		{Email: "[email protected]"},
+		{Email: "[email protected]"},
+	}).Error; err != nil {
+		t.Fatalf("seed clients: %v", err)
+	}
+	if err := src.Model(&model.ClientRecord{}).Where("email = ?", "[email protected]").
+		Update("enable", false).Error; err != nil {
+		t.Fatalf("disable client: %v", err)
+	}
+	if err := src.Create(&model.Node{Name: "n-off", Address: "1.2.3.4", Port: 1, ApiToken: "tok"}).Error; err != nil {
+		t.Fatalf("seed node: %v", err)
+	}
+	if err := src.Model(&model.Node{}).Where("name = ?", "n-off").
+		Update("enable", false).Error; err != nil {
+		t.Fatalf("disable node: %v", err)
+	}
+	if sqlDB, err := src.DB(); err == nil {
+		sqlDB.Close()
+	}
+
+	dst, err := gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: logger.Discard})
+	if err != nil {
+		t.Fatalf("open postgres: %v", err)
+	}
+	if err := dst.Migrator().DropTable(migrationModels()...); err != nil {
+		t.Fatalf("drop tables: %v", err)
+	}
+
+	if err := MigrateData(srcPath, dsn); err != nil {
+		t.Fatalf("MigrateData: %v", err)
+	}
+
+	var off model.ClientRecord
+	if err := dst.Where("email = ?", "[email protected]").First(&off).Error; err != nil {
+		t.Fatalf("load disabled client: %v", err)
+	}
+	if off.Enable {
+		t.Fatalf("disabled client re-enabled after migration (enable=%v)", off.Enable)
+	}
+
+	var on model.ClientRecord
+	if err := dst.Where("email = ?", "[email protected]").First(&on).Error; err != nil {
+		t.Fatalf("load enabled client: %v", err)
+	}
+	if !on.Enable {
+		t.Fatalf("enabled client wrongly disabled after migration")
+	}
+
+	var node model.Node
+	if err := dst.Where("name = ?", "n-off").First(&node).Error; err != nil {
+		t.Fatalf("load node: %v", err)
+	}
+	if node.Enable {
+		t.Fatalf("disabled node re-enabled after migration")
+	}
+}