mirror of
				https://github.com/komari-monitor/komari.git
				synced 2025-11-03 21:43:14 +00:00 
			
		
		
		
	refactor: 重写备份上传功能
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -14,6 +14,7 @@ komari.db
 | 
			
		||||
/data
 | 
			
		||||
/utils/geoip/data
 | 
			
		||||
.vscode/
 | 
			
		||||
/backup
 | 
			
		||||
 | 
			
		||||
# Test binary, built with `go test -c`
 | 
			
		||||
*.test
 | 
			
		||||
 
 | 
			
		||||
@@ -14,7 +14,6 @@ import (
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/komari-monitor/komari/api"
 | 
			
		||||
	"github.com/komari-monitor/komari/cmd/flags"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 只有一个备份恢复操作在进行
 | 
			
		||||
@@ -43,8 +42,14 @@ func UploadBackup(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建临时文件保存上传的zip
 | 
			
		||||
	tempFile, err := os.CreateTemp("", "backup-*.zip")
 | 
			
		||||
	// 确保data目录存在
 | 
			
		||||
	if err := os.MkdirAll("./data", 0755); err != nil {
 | 
			
		||||
		api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error creating data directory: %v", err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建临时文件保存上传的zip(先校验,再落地到固定位置)
 | 
			
		||||
	tempFile, err := os.CreateTemp("", "backup-upload-*.zip")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error creating temporary file: %v", err))
 | 
			
		||||
		return
 | 
			
		||||
@@ -61,112 +66,59 @@ func UploadBackup(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
	tempFile.Close() // 关闭文件以便后续操作
 | 
			
		||||
 | 
			
		||||
	// 打开zip文件准备解压
 | 
			
		||||
	zipReader, err := zip.OpenReader(tempFilePath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error opening zip file: %v", err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer zipReader.Close()
 | 
			
		||||
 | 
			
		||||
	// 检查是否包含备份标记文件
 | 
			
		||||
	hasMarkupFile := false
 | 
			
		||||
	for _, zipFile := range zipReader.File {
 | 
			
		||||
		if zipFile.Name == "komari-backup-markup" {
 | 
			
		||||
			hasMarkupFile = true
 | 
			
		||||
	// 基础校验:检查是否包含标记文件
 | 
			
		||||
	if zr, err := zip.OpenReader(tempFilePath); err == nil {
 | 
			
		||||
		hasMarkup := false
 | 
			
		||||
		for _, f := range zr.File {
 | 
			
		||||
			if f.Name == "komari-backup-markup" {
 | 
			
		||||
				hasMarkup = true
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	if !hasMarkupFile {
 | 
			
		||||
		zr.Close()
 | 
			
		||||
		if !hasMarkup {
 | 
			
		||||
			api.RespondError(c, http.StatusBadRequest, "Invalid backup file: missing komari-backup-markup file")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	// 确保data目录存在
 | 
			
		||||
	if err := os.MkdirAll("./data", 0755); err != nil {
 | 
			
		||||
		api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error creating data directory: %v", err))
 | 
			
		||||
	} else {
 | 
			
		||||
		api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error opening zip file: %v", err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取数据库文件名
 | 
			
		||||
	dbFileName := filepath.Base(flags.DatabaseFile)
 | 
			
		||||
 | 
			
		||||
	// 解压文件
 | 
			
		||||
	for _, zipFile := range zipReader.File {
 | 
			
		||||
		// 检查文件路径是否安全(防止路径遍历攻击)
 | 
			
		||||
		filePath := zipFile.Name
 | 
			
		||||
		if strings.Contains(filePath, "..") {
 | 
			
		||||
			log.Printf("Potentially unsafe path in zip: %s, skipping", filePath)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 跳过备份标记文件
 | 
			
		||||
		if filePath == "komari-backup-markup" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 确定目标路径
 | 
			
		||||
		var destPath string
 | 
			
		||||
		if filePath == dbFileName {
 | 
			
		||||
			// 如果是数据库文件,恢复到数据库文件路径
 | 
			
		||||
			destPath = flags.DatabaseFile
 | 
			
		||||
		} else {
 | 
			
		||||
			// 其他文件恢复到data目录
 | 
			
		||||
			destPath = filepath.Join("./data", filePath)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 如果是目录,创建目录
 | 
			
		||||
		if zipFile.FileInfo().IsDir() {
 | 
			
		||||
			if err := os.MkdirAll(destPath, 0755); err != nil {
 | 
			
		||||
				log.Printf("Error creating directory %s: %v", destPath, err)
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 确保目标文件的目录存在
 | 
			
		||||
		destDir := filepath.Dir(destPath)
 | 
			
		||||
		if err := os.MkdirAll(destDir, 0755); err != nil {
 | 
			
		||||
			log.Printf("Error creating directory %s: %v", destDir, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 打开zip中的文件
 | 
			
		||||
		srcFile, err := zipFile.Open()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Printf("Error opening file from zip %s: %v", filePath, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 创建目标文件
 | 
			
		||||
		destFile, err := os.Create(destPath)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			srcFile.Close()
 | 
			
		||||
			log.Printf("Error creating file %s: %v", destPath, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 复制内容
 | 
			
		||||
		_, err = io.Copy(destFile, srcFile)
 | 
			
		||||
		srcFile.Close()
 | 
			
		||||
		destFile.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Printf("Error extracting file %s: %v", destPath, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 保持原始文件的修改时间
 | 
			
		||||
		if err := os.Chtimes(destPath, zipFile.Modified, zipFile.Modified); err != nil {
 | 
			
		||||
			log.Printf("Error setting file time for %s: %v", destPath, err)
 | 
			
		||||
		}
 | 
			
		||||
	// 将校验通过的临时文件移动到固定路径 ./data/backup.zip
 | 
			
		||||
	finalPath := filepath.Join(".", "data", "backup.zip")
 | 
			
		||||
	// 如存在旧文件,先删除
 | 
			
		||||
	_ = os.Remove(finalPath)
 | 
			
		||||
	if err := os.Rename(tempFilePath, finalPath); err != nil {
 | 
			
		||||
		// fallback:拷贝
 | 
			
		||||
		in, err2 := os.Open(tempFilePath)
 | 
			
		||||
		if err2 != nil {
 | 
			
		||||
			api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error preparing backup file: %v", err))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		defer in.Close()
 | 
			
		||||
		out, err2 := os.Create(finalPath)
 | 
			
		||||
		if err2 != nil {
 | 
			
		||||
			api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error creating target backup file: %v", err2))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if _, err2 = io.Copy(out, in); err2 != nil {
 | 
			
		||||
			out.Close()
 | 
			
		||||
			api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error writing target backup file: %v", err2))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		out.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 返回:已保存备份,重启后将自动恢复
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"status":  "success",
 | 
			
		||||
		"message": "Backup restored successfully. The service will restart shortly.",
 | 
			
		||||
		"message": "Backup uploaded successfully. The service will restart and apply the backup.",
 | 
			
		||||
		"path":    "./data/backup.zip",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		log.Println("Backup restored, restarting service in 2 seconds...")
 | 
			
		||||
		log.Println("Backup uploaded, restarting service in 2 seconds to apply on startup...")
 | 
			
		||||
		time.Sleep(2 * time.Second)
 | 
			
		||||
		os.Exit(0)
 | 
			
		||||
	}()
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,16 @@
 | 
			
		||||
package dbcore
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"archive/zip"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/komari-monitor/komari/cmd/flags"
 | 
			
		||||
	"github.com/komari-monitor/komari/common"
 | 
			
		||||
@@ -16,6 +22,151 @@ import (
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// zipDirectoryExcluding 将 srcDir 打包为 dstZip,exclude 是绝对路径集合需要排除
 | 
			
		||||
func zipDirectoryExcluding(srcDir, dstZip string, exclude map[string]struct{}) error {
 | 
			
		||||
	// 规范化排除路径为绝对路径
 | 
			
		||||
	normExclude := make(map[string]struct{}, len(exclude))
 | 
			
		||||
	for p := range exclude {
 | 
			
		||||
		abs, _ := filepath.Abs(p)
 | 
			
		||||
		normExclude[abs] = struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	out, err := os.Create(dstZip)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer out.Close()
 | 
			
		||||
 | 
			
		||||
	zw := zip.NewWriter(out)
 | 
			
		||||
	defer zw.Close()
 | 
			
		||||
 | 
			
		||||
	absSrc, _ := filepath.Abs(srcDir)
 | 
			
		||||
	walkErr := filepath.Walk(absSrc, func(path string, info os.FileInfo, err error) error {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		// 排除 backup.zip 本身
 | 
			
		||||
		if _, ok := normExclude[path]; ok {
 | 
			
		||||
			if info.IsDir() {
 | 
			
		||||
				return filepath.SkipDir
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		// 计算 zip 内相对路径
 | 
			
		||||
		rel, err := filepath.Rel(absSrc, path)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		// 根目录跳过
 | 
			
		||||
		if rel == "." {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		// 替换为正斜杠
 | 
			
		||||
		zipName := filepath.ToSlash(rel)
 | 
			
		||||
 | 
			
		||||
		if info.IsDir() {
 | 
			
		||||
			_, err := zw.Create(zipName + "/")
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		// 普通文件
 | 
			
		||||
		fh, err := os.Open(path)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		w, err := zw.Create(zipName)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fh.Close()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if _, err := io.Copy(w, fh); err != nil {
 | 
			
		||||
			fh.Close()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		fh.Close()
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if walkErr != nil {
 | 
			
		||||
		return walkErr
 | 
			
		||||
	}
 | 
			
		||||
	return zw.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// removeAllInDirExcept 删除 dir 下除 exclude 指定绝对路径外的所有文件和文件夹
 | 
			
		||||
func removeAllInDirExcept(dir string, exclude map[string]struct{}) error {
 | 
			
		||||
	absDir, err := filepath.Abs(dir)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	normExclude := make(map[string]struct{}, len(exclude))
 | 
			
		||||
	for p := range exclude {
 | 
			
		||||
		abs, _ := filepath.Abs(p)
 | 
			
		||||
		normExclude[abs] = struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
	entries, err := os.ReadDir(absDir)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, e := range entries {
 | 
			
		||||
		full := filepath.Join(absDir, e.Name())
 | 
			
		||||
		if _, ok := normExclude[full]; ok {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if err := os.RemoveAll(full); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// unzipToDir 将 zipPath 解压到 dstDir,包含路径遍历保护
 | 
			
		||||
func unzipToDir(zipPath, dstDir string) error {
 | 
			
		||||
	zr, err := zip.OpenReader(zipPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer zr.Close()
 | 
			
		||||
 | 
			
		||||
	if err := os.MkdirAll(dstDir, 0755); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	absDst, _ := filepath.Abs(dstDir)
 | 
			
		||||
 | 
			
		||||
	for _, f := range zr.File {
 | 
			
		||||
		// 构造目标路径并做路径遍历保护
 | 
			
		||||
		cleanName := filepath.Clean(f.Name)
 | 
			
		||||
		targetPath := filepath.Join(absDst, cleanName)
 | 
			
		||||
		if !strings.HasPrefix(targetPath, absDst+string(os.PathSeparator)) && targetPath != absDst {
 | 
			
		||||
			return fmt.Errorf("illegal file path in zip: %s", f.Name)
 | 
			
		||||
		}
 | 
			
		||||
		if f.FileInfo().IsDir() {
 | 
			
		||||
			if err := os.MkdirAll(targetPath, 0755); err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		rc, err := f.Open()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		out, err := os.Create(targetPath)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			rc.Close()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if _, err := io.Copy(out, rc); err != nil {
 | 
			
		||||
			out.Close()
 | 
			
		||||
			rc.Close()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		out.Close()
 | 
			
		||||
		rc.Close()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// mergeClientInfo 将旧版ClientInfo数据迁移到新版Client表
 | 
			
		||||
func mergeClientInfo(db *gorm.DB) {
 | 
			
		||||
	var clientInfos []common.ClientInfo
 | 
			
		||||
@@ -212,6 +363,50 @@ func GetDBInstance() *gorm.DB {
 | 
			
		||||
	once.Do(func() {
 | 
			
		||||
		var err error
 | 
			
		||||
 | 
			
		||||
		// 在数据库初始化前执行:如果存在 ./data/backup.zip,则进行恢复逻辑
 | 
			
		||||
		func() {
 | 
			
		||||
			backupZipPath := filepath.Join(".", "data", "backup.zip")
 | 
			
		||||
			if _, statErr := os.Stat(backupZipPath); statErr == nil {
 | 
			
		||||
				// 4. 把除了 ./data/backup.zip 之外的所有文件压缩到 ./backup/{time}.zip
 | 
			
		||||
				if err := os.MkdirAll("./backup", 0755); err != nil {
 | 
			
		||||
					log.Printf("[restore] failed to create backup dir: %v", err)
 | 
			
		||||
				} else {
 | 
			
		||||
					tsName := time.Now().Format("20060102-150405")
 | 
			
		||||
					bakPath := filepath.Join("./backup", fmt.Sprintf("%s.zip", tsName))
 | 
			
		||||
					if zipErr := zipDirectoryExcluding("./data", bakPath, map[string]struct{}{backupZipPath: {}}); zipErr != nil {
 | 
			
		||||
						log.Printf("[restore] failed to zip current data: %v", zipErr)
 | 
			
		||||
					} else {
 | 
			
		||||
						log.Printf("[restore] current data zipped to %s", bakPath)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 5. 删除除了 ./data/backup.zip 之外的所有文件
 | 
			
		||||
				if delErr := removeAllInDirExcept("./data", map[string]struct{}{backupZipPath: {}}); delErr != nil {
 | 
			
		||||
					log.Printf("[restore] failed to cleanup data dir: %v", delErr)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 6. 解压 ./data/backup.zip 到 ./data
 | 
			
		||||
				if unzipErr := unzipToDir(backupZipPath, "./data"); unzipErr != nil {
 | 
			
		||||
					log.Printf("[restore] failed to unzip backup into data: %v", unzipErr)
 | 
			
		||||
				} else {
 | 
			
		||||
					log.Printf("[restore] backup.zip extracted to ./data")
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 7. 删除 ./data/backup.zip
 | 
			
		||||
				if rmErr := os.Remove(backupZipPath); rmErr != nil {
 | 
			
		||||
					log.Printf("[restore] failed to remove backup.zip: %v", rmErr)
 | 
			
		||||
				} else {
 | 
			
		||||
					log.Printf("[restore] backup.zip removed")
 | 
			
		||||
				}
 | 
			
		||||
				// 8. 删除标记
 | 
			
		||||
				if rmErr := os.Remove("./data/komari-backup-markup"); rmErr != nil {
 | 
			
		||||
					log.Printf("[restore] failed to remove komari-backup-markup: %v", rmErr)
 | 
			
		||||
				} else {
 | 
			
		||||
					log.Printf("[restore] komari-backup-markup removed")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		logConfig := &gorm.Config{
 | 
			
		||||
			Logger: logger.Default.LogMode(logger.Silent),
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user