Browse Source

add import db api route

Hamidreza Ghavami 1 year ago
parent
commit
55c1fe26fb
2 changed files with 110 additions and 3 deletions
  1. 20 2
      web/controller/server.go
  2. 90 1
      web/service/server.go

+ 20 - 2
web/controller/server.go

@@ -41,6 +41,7 @@ func (a *ServerController) initRouter(g *gin.RouterGroup) {
 	g.POST("/logs/:count", a.getLogs)
 	g.POST("/getConfigJson", a.getConfigJson)
 	g.GET("/getDb", a.getDb)
+	g.POST("/importDB", a.importDB)
 	g.POST("/getNewX25519Cert", a.getNewX25519Cert)
 }
 
@@ -99,8 +100,8 @@ func (a *ServerController) stopXrayService(c *gin.Context) {
 		return
 	}
 	jsonMsg(c, "Xray stoped", err)
-
 }
+
 func (a *ServerController) restartXrayService(c *gin.Context) {
 	err := a.serverService.RestartXrayService()
 	if err != nil {
@@ -108,7 +109,6 @@ func (a *ServerController) restartXrayService(c *gin.Context) {
 		return
 	}
 	jsonMsg(c, "Xray restarted", err)
-
 }
 
 func (a *ServerController) getLogs(c *gin.Context) {
@@ -144,6 +144,24 @@ func (a *ServerController) getDb(c *gin.Context) {
 	c.Writer.Write(db)
 }
 
+func (a *ServerController) importDB(c *gin.Context) {
+	// Get the file from the request body
+	file, _, err := c.Request.FormFile("db")
+	if err != nil {
+		jsonMsg(c, "Error reading db file", err)
+		return
+	}
+	defer file.Close()
+	// Import it
+	err = a.serverService.ImportDB(file)
+	if err != nil {
+		jsonMsg(c, "", err)
+		return
+	}
+	a.lastGetStatusTime = time.Now()
+	jsonObj(c, "Import DB", nil)
+}
+
 func (a *ServerController) getNewX25519Cert(c *gin.Context) {
 	cert, err := a.serverService.GetNewX25519Cert()
 	if err != nil {

+ 90 - 1
web/service/server.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
+	"mime/multipart"
 	"net/http"
 	"os"
 	"os/exec"
@@ -14,7 +15,9 @@ import (
 	"strings"
 	"time"
 	"x-ui/config"
+	"x-ui/database"
 	"x-ui/logger"
+	"x-ui/util/common"
 	"x-ui/util/sys"
 	"x-ui/xray"
 
@@ -73,7 +76,8 @@ type Release struct {
 }
 
 type ServerService struct {
-	xrayService XrayService
+	xrayService    XrayService
+	inboundService InboundService
 }
 
 func (s *ServerService) GetStatus(lastStatus *Status) *Status {
@@ -395,6 +399,91 @@ func (s *ServerService) GetDb() ([]byte, error) {
 	return fileContents, nil
 }
 
+func (s *ServerService) ImportDB(file multipart.File) error {
+	// Check if the file is a SQLite database
+	isValidDb, err := database.IsSQLiteDB(file)
+	if err != nil {
+		return common.NewErrorf("Error checking db file format: %v", err)
+	}
+	if !isValidDb {
+		return common.NewError("Invalid db file format")
+	}
+
+	// Save the file as temporary file
+	tempPath := fmt.Sprintf("%s.temp", config.GetDBPath())
+	tempFile, err := os.Create(tempPath)
+	if err != nil {
+		return common.NewErrorf("Error creating temporary db file: %v", err)
+	}
+	defer tempFile.Close()
+
+	// Reset the file reader to the beginning
+	_, err = file.Seek(0, 0)
+	if err != nil {
+		defer os.Remove(tempPath)
+		return common.NewErrorf("Error resetting file reader: %v", err)
+	}
+
+	// Save temp file
+	_, err = io.Copy(tempFile, file)
+	if err != nil {
+		defer os.Remove(tempPath)
+		return common.NewErrorf("Error saving db: %v", err)
+	}
+
+	// Check if we can init db or not
+	err = database.InitDB(tempPath)
+	if err != nil {
+		defer os.Remove(tempPath)
+		return common.NewErrorf("Error checking db: %v", err)
+	}
+
+	// Stop Xray if its running
+	if s.xrayService.IsXrayRunning() {
+		err := s.StopXrayService()
+		if err != nil {
+			defer os.Remove(tempPath)
+			return common.NewErrorf("Failed to stop Xray: %v", err)
+		}
+	}
+
+	// Backup db for fallback
+	fallbackPath := fmt.Sprintf("%s.backup", config.GetDBPath())
+	err = os.Rename(config.GetDBPath(), fallbackPath)
+	if err != nil {
+		defer os.Remove(tempPath)
+		return common.NewErrorf("Error backup temporary db file: %v", err)
+	}
+
+	// Move temp to DB path
+	err = os.Rename(tempPath, config.GetDBPath())
+	if err != nil {
+		defer os.Remove(tempPath)
+		defer os.Rename(fallbackPath, config.GetDBPath())
+		return common.NewErrorf("Error moving db file: %v", err)
+	}
+
+	// Migrate DB
+	err = database.InitDB(config.GetDBPath())
+	if err != nil {
+		defer os.Rename(fallbackPath, config.GetDBPath())
+		return common.NewErrorf("Error migrating db: %v", err)
+	}
+	s.inboundService.MigrationRequirements()
+	s.inboundService.RemoveOrphanedTraffics()
+
+	// remove fallback file
+	defer os.Remove(fallbackPath)
+
+	// Start Xray
+	err = s.RestartXrayService()
+	if err != nil {
+		return common.NewErrorf("Imported DB but Failed to start Xray: %v", err)
+	}
+
+	return nil
+}
+
 func (s *ServerService) GetNewX25519Cert() (interface{}, error) {
 	// Run the command
 	cmd := exec.Command(xray.GetBinaryPath(), "x25519")