diff --git a/.gitignore b/.gitignore index ea35d49..1b7523e 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ komari.db /data /utils/geoip/data .vscode/ +/backup # Test binary, built with `go test -c` *.test diff --git a/api/admin/upload.go b/api/admin/upload.go index 213fe34..8f8a673 100644 --- a/api/admin/upload.go +++ b/api/admin/upload.go @@ -14,7 +14,6 @@ import ( "github.com/gin-gonic/gin" "github.com/komari-monitor/komari/api" - "github.com/komari-monitor/komari/cmd/flags" ) // 只有一个备份恢复操作在进行 @@ -43,8 +42,14 @@ func UploadBackup(c *gin.Context) { return } - // 创建临时文件保存上传的zip - tempFile, err := os.CreateTemp("", "backup-*.zip") + // 确保data目录存在 + if err := os.MkdirAll("./data", 0755); err != nil { + api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error creating data directory: %v", err)) + return + } + + // 创建临时文件保存上传的zip(先校验,再落地到固定位置) + tempFile, err := os.CreateTemp("", "backup-upload-*.zip") if err != nil { api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error creating temporary file: %v", err)) return @@ -61,112 +66,59 @@ func UploadBackup(c *gin.Context) { } tempFile.Close() // 关闭文件以便后续操作 - // 打开zip文件准备解压 - zipReader, err := zip.OpenReader(tempFilePath) - if err != nil { + // 基础校验:检查是否包含标记文件 + if zr, err := zip.OpenReader(tempFilePath); err == nil { + hasMarkup := false + for _, f := range zr.File { + if f.Name == "komari-backup-markup" { + hasMarkup = true + break + } + } + zr.Close() + if !hasMarkup { + api.RespondError(c, http.StatusBadRequest, "Invalid backup file: missing komari-backup-markup file") + return + } + } else { api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error opening zip file: %v", err)) return } - defer zipReader.Close() - // 检查是否包含备份标记文件 - hasMarkupFile := false - for _, zipFile := range zipReader.File { - if zipFile.Name == "komari-backup-markup" { - hasMarkupFile = true - break - } - } - if !hasMarkupFile { - api.RespondError(c, http.StatusBadRequest, "Invalid backup file: missing komari-backup-markup file") - return - } - - // 确保data目录存在 - if err := os.MkdirAll("./data", 0755); err != nil { - api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error creating data directory: %v", err)) - return - } - - // 获取数据库文件名 - dbFileName := filepath.Base(flags.DatabaseFile) - - // 解压文件 - for _, zipFile := range zipReader.File { - // 检查文件路径是否安全(防止路径遍历攻击) - filePath := zipFile.Name - if strings.Contains(filePath, "..") { - log.Printf("Potentially unsafe path in zip: %s, skipping", filePath) - continue - } - - // 跳过备份标记文件 - if filePath == "komari-backup-markup" { - continue - } - - // 确定目标路径 - var destPath string - if filePath == dbFileName { - // 如果是数据库文件,恢复到数据库文件路径 - destPath = flags.DatabaseFile - } else { - // 其他文件恢复到data目录 - destPath = filepath.Join("./data", filePath) - } - - // 如果是目录,创建目录 - if zipFile.FileInfo().IsDir() { - if err := os.MkdirAll(destPath, 0755); err != nil { - log.Printf("Error creating directory %s: %v", destPath, err) - } - continue - } - - // 确保目标文件的目录存在 - destDir := filepath.Dir(destPath) - if err := os.MkdirAll(destDir, 0755); err != nil { - log.Printf("Error creating directory %s: %v", destDir, err) - continue - } - - // 打开zip中的文件 - srcFile, err := zipFile.Open() - if err != nil { - log.Printf("Error opening file from zip %s: %v", filePath, err) - continue - } - - // 创建目标文件 - destFile, err := os.Create(destPath) - if err != nil { - srcFile.Close() - log.Printf("Error creating file %s: %v", destPath, err) - continue - } - - // 复制内容 - _, err = io.Copy(destFile, srcFile) - srcFile.Close() - destFile.Close() - if err != nil { - log.Printf("Error extracting file %s: %v", destPath, err) - continue - } - - // 保持原始文件的修改时间 - if err := os.Chtimes(destPath, zipFile.Modified, zipFile.Modified); err != nil { - log.Printf("Error setting file time for %s: %v", destPath, err) + // 将校验通过的临时文件移动到固定路径 ./data/backup.zip + finalPath := filepath.Join(".", "data", "backup.zip") + // 如存在旧文件,先删除 + _ = os.Remove(finalPath) + if err := os.Rename(tempFilePath, finalPath); err != nil { + // fallback:拷贝 + in, err2 := os.Open(tempFilePath) + if err2 != nil { + api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error preparing backup file: %v", err)) + return } + defer in.Close() + out, err2 := os.Create(finalPath) + if err2 != nil { + api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error creating target backup file: %v", err2)) + return + } + if _, err2 = io.Copy(out, in); err2 != nil { + out.Close() + api.RespondError(c, http.StatusInternalServerError, fmt.Sprintf("Error writing target backup file: %v", err2)) + return + } + out.Close() } + // 返回:已保存备份,重启后将自动恢复 c.JSON(http.StatusOK, gin.H{ "status": "success", - "message": "Backup restored successfully. The service will restart shortly.", + "message": "Backup uploaded successfully. The service will restart and apply the backup.", + "path": "./data/backup.zip", }) go func() { - log.Println("Backup restored, restarting service in 2 seconds...") + log.Println("Backup uploaded, restarting service in 2 seconds to apply on startup...") time.Sleep(2 * time.Second) os.Exit(0) }() diff --git a/database/dbcore/dbcore.go b/database/dbcore/dbcore.go index c902bf5..7f6e219 100644 --- a/database/dbcore/dbcore.go +++ b/database/dbcore/dbcore.go @@ -1,10 +1,16 @@ package dbcore import ( + "archive/zip" "encoding/json" "fmt" + "io" "log" + "os" + "path/filepath" + "strings" "sync" + "time" "github.com/komari-monitor/komari/cmd/flags" "github.com/komari-monitor/komari/common" @@ -16,6 +22,151 @@ import ( "gorm.io/gorm/logger" ) +// zipDirectoryExcluding 将 srcDir 打包为 dstZip,exclude 是绝对路径集合需要排除 +func zipDirectoryExcluding(srcDir, dstZip string, exclude map[string]struct{}) error { + // 规范化排除路径为绝对路径 + normExclude := make(map[string]struct{}, len(exclude)) + for p := range exclude { + abs, _ := filepath.Abs(p) + normExclude[abs] = struct{}{} + } + + out, err := os.Create(dstZip) + if err != nil { + return err + } + defer out.Close() + + zw := zip.NewWriter(out) + defer zw.Close() + + absSrc, _ := filepath.Abs(srcDir) + walkErr := filepath.Walk(absSrc, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + // 排除 backup.zip 本身 + if _, ok := normExclude[path]; ok { + if info.IsDir() { + return filepath.SkipDir + } + return nil + } + // 计算 zip 内相对路径 + rel, err := filepath.Rel(absSrc, path) + if err != nil { + return err + } + // 根目录跳过 + if rel == "." { + return nil + } + // 替换为正斜杠 + zipName := filepath.ToSlash(rel) + + if info.IsDir() { + _, err := zw.Create(zipName + "/") + return err + } + // 普通文件 + fh, err := os.Open(path) + if err != nil { + return err + } + w, err := zw.Create(zipName) + if err != nil { + fh.Close() + return err + } + if _, err := io.Copy(w, fh); err != nil { + fh.Close() + return err + } + fh.Close() + return nil + }) + if walkErr != nil { + return walkErr + } + return zw.Close() +} + +// removeAllInDirExcept 删除 dir 下除 exclude 指定绝对路径外的所有文件和文件夹 +func removeAllInDirExcept(dir string, exclude map[string]struct{}) error { + absDir, err := filepath.Abs(dir) + if err != nil { + return err + } + normExclude := make(map[string]struct{}, len(exclude)) + for p := range exclude { + abs, _ := filepath.Abs(p) + normExclude[abs] = struct{}{} + } + entries, err := os.ReadDir(absDir) + if err != nil { + return err + } + for _, e := range entries { + full := filepath.Join(absDir, e.Name()) + if _, ok := normExclude[full]; ok { + continue + } + if err := os.RemoveAll(full); err != nil { + return err + } + } + return nil +} + +// unzipToDir 将 zipPath 解压到 dstDir,包含路径遍历保护 +func unzipToDir(zipPath, dstDir string) error { + zr, err := zip.OpenReader(zipPath) + if err != nil { + return err + } + defer zr.Close() + + if err := os.MkdirAll(dstDir, 0755); err != nil { + return err + } + absDst, _ := filepath.Abs(dstDir) + + for _, f := range zr.File { + // 构造目标路径并做路径遍历保护 + cleanName := filepath.Clean(f.Name) + targetPath := filepath.Join(absDst, cleanName) + if !strings.HasPrefix(targetPath, absDst+string(os.PathSeparator)) && targetPath != absDst { + return fmt.Errorf("illegal file path in zip: %s", f.Name) + } + if f.FileInfo().IsDir() { + if err := os.MkdirAll(targetPath, 0755); err != nil { + return err + } + continue + } + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return err + } + rc, err := f.Open() + if err != nil { + return err + } + out, err := os.Create(targetPath) + if err != nil { + rc.Close() + return err + } + if _, err := io.Copy(out, rc); err != nil { + out.Close() + rc.Close() + return err + } + out.Close() + rc.Close() + } + return nil +} + // mergeClientInfo 将旧版ClientInfo数据迁移到新版Client表 func mergeClientInfo(db *gorm.DB) { var clientInfos []common.ClientInfo @@ -212,6 +363,50 @@ func GetDBInstance() *gorm.DB { once.Do(func() { var err error + // 在数据库初始化前执行:如果存在 ./data/backup.zip,则进行恢复逻辑 + func() { + backupZipPath := filepath.Join(".", "data", "backup.zip") + if _, statErr := os.Stat(backupZipPath); statErr == nil { + // 4. 把除了 ./data/backup.zip 之外的所有文件压缩到 ./backup/{time}.zip + if err := os.MkdirAll("./backup", 0755); err != nil { + log.Printf("[restore] failed to create backup dir: %v", err) + } else { + tsName := time.Now().Format("20060102-150405") + bakPath := filepath.Join("./backup", fmt.Sprintf("%s.zip", tsName)) + if zipErr := zipDirectoryExcluding("./data", bakPath, map[string]struct{}{backupZipPath: {}}); zipErr != nil { + log.Printf("[restore] failed to zip current data: %v", zipErr) + } else { + log.Printf("[restore] current data zipped to %s", bakPath) + } + } + + // 5. 删除除了 ./data/backup.zip 之外的所有文件 + if delErr := removeAllInDirExcept("./data", map[string]struct{}{backupZipPath: {}}); delErr != nil { + log.Printf("[restore] failed to cleanup data dir: %v", delErr) + } + + // 6. 解压 ./data/backup.zip 到 ./data + if unzipErr := unzipToDir(backupZipPath, "./data"); unzipErr != nil { + log.Printf("[restore] failed to unzip backup into data: %v", unzipErr) + } else { + log.Printf("[restore] backup.zip extracted to ./data") + } + + // 7. 删除 ./data/backup.zip + if rmErr := os.Remove(backupZipPath); rmErr != nil { + log.Printf("[restore] failed to remove backup.zip: %v", rmErr) + } else { + log.Printf("[restore] backup.zip removed") + } + // 8. 删除标记 + if rmErr := os.Remove("./data/komari-backup-markup"); rmErr != nil { + log.Printf("[restore] failed to remove komari-backup-markup: %v", rmErr) + } else { + log.Printf("[restore] komari-backup-markup removed") + } + } + }() + logConfig := &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }