diff --git a/api/admin/exec.go b/api/admin/exec.go new file mode 100644 index 0000000..21b9976 --- /dev/null +++ b/api/admin/exec.go @@ -0,0 +1,74 @@ +package admin + +import ( + "encoding/json" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/komari-monitor/komari/database/tasks" + "github.com/komari-monitor/komari/utils" + "github.com/komari-monitor/komari/ws" +) + +// 接受数据类型: +// - command: string +// - clients: []string (客户端 UUID 列表) +func Exec(c *gin.Context) { + var req struct { + Command string `json:"command" binding:"required"` + Clients []string `json:"clients" binding:"required"` + } + var onlineClients []string + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(400, gin.H{"status": "error", "message": "Invalid request"}) + return + } + for uuid := range ws.ConnectedClients { + if contain(req.Clients, uuid) { + onlineClients = append(onlineClients, uuid) + } else { + c.JSON(400, gin.H{"status": "error", "message": "Client not connected: " + uuid}) + return + } + } + if len(onlineClients) == 0 { + c.JSON(400, gin.H{"status": "error", "message": "No clients connected"}) + return + } + taskId := utils.GenerateRandomString(16) + if err := tasks.CreateTask(taskId, onlineClients, req.Command); err != nil { + c.JSON(500, gin.H{"status": "error", "message": "Failed to create task: " + err.Error()}) + return + } + for _, uuid := range onlineClients { + var send struct { + Command string `json:"command"` + TaskId string `json:"task_id"` + } + send.Command = req.Command + send.TaskId = taskId + + payload, _ := json.Marshal(send) + client := ws.ConnectedClients[uuid] + if client != nil { + client.WriteMessage(websocket.TextMessage, payload) + } else { + c.JSON(400, gin.H{"status": "error", "message": "Client connection is null: " + uuid}) + return + } + } + c.JSON(200, gin.H{ + "status": "success", + "message": "Command sent to clients", + "task_id": taskId, + "clients": onlineClients, + }) +} +func contain(clients []string, uuid string) bool { + for _, client := range clients { + if client == uuid { + return true + } + } + return false +} diff --git a/api/admin/task.go b/api/admin/task.go new file mode 100644 index 0000000..b331d31 --- /dev/null +++ b/api/admin/task.go @@ -0,0 +1,107 @@ +package admin + +import ( + "github.com/gin-gonic/gin" + "github.com/komari-monitor/komari/database/tasks" +) + +func GetTasks(c *gin.Context) { + tasks, err := tasks.GetAllTasks() + if err != nil { + c.JSON(500, gin.H{"status": "error", "message": "Failed to retrieve tasks: " + err.Error()}) + return + } + c.JSON(200, gin.H{"status": "success", "tasks": tasks}) +} + +func GetTaskById(c *gin.Context) { + taskId := c.Param("task_id") + if taskId == "" { + c.JSON(400, gin.H{"status": "error", "message": "Task ID is required"}) + return + } + task, err := tasks.GetTaskByTaskId(taskId) + if err != nil { + c.JSON(500, gin.H{"status": "error", "message": "Failed to retrieve task: " + err.Error()}) + return + } + if task == nil { + c.JSON(404, gin.H{"status": "error", "message": "Task not found"}) + return + } + c.JSON(200, gin.H{"status": "success", "task": task}) +} + +func GetTasksByClientId(c *gin.Context) { + clientId := c.Param("uuid") + if clientId == "" { + c.JSON(400, gin.H{"status": "error", "message": "Client ID is required"}) + return + } + tasks, err := tasks.GetTasksByClientId(clientId) + if err != nil { + c.JSON(500, gin.H{"status": "error", "message": "Failed to retrieve tasks: " + err.Error()}) + return + } + if len(tasks) == 0 { + c.JSON(404, gin.H{"status": "error", "message": "No tasks found for this client"}) + return + } + c.JSON(200, gin.H{"status": "success", "tasks": tasks}) +} + +func GetSpecificTaskResult(c *gin.Context) { + taskId := c.Param("task_id") + clientId := c.Param("uuid") + if taskId == "" || clientId == "" { + c.JSON(400, gin.H{"status": "error", "message": "Task ID and Client ID are required"}) + return + } + result, err := tasks.GetSpecificTaskResult(taskId, clientId) + if err != nil { + c.JSON(500, gin.H{"status": "error", "message": "Failed to retrieve task result: " + err.Error()}) + return + } + if result == nil { + c.JSON(404, gin.H{"status": "error", "message": "No result found for this task and client"}) + return + } + c.JSON(200, gin.H{"status": "success", "result": result}) +} + +// Param: task_id +func GetTaskResultsByTaskId(c *gin.Context) { + taskId := c.Param("task_id") + if taskId == "" { + c.JSON(400, gin.H{"status": "error", "message": "Task ID is required"}) + return + } + results, err := tasks.GetTaskResultsByTaskId(taskId) + if err != nil { + c.JSON(500, gin.H{"status": "error", "message": "Failed to retrieve task results: " + err.Error()}) + return + } + if len(results) == 0 { + c.JSON(404, gin.H{"status": "error", "message": "No results found for this task"}) + return + } + c.JSON(200, gin.H{"status": "success", "results": results}) +} + +func GetAllTaskResultByUUID(c *gin.Context) { + clientId := c.Param("uuid") + if clientId == "" { + c.JSON(400, gin.H{"status": "error", "message": "Client ID is required"}) + return + } + results, err := tasks.GetAllTasksResultByUUID(clientId) + if err != nil { + c.JSON(500, gin.H{"status": "error", "message": "Failed to retrieve tasks: " + err.Error()}) + return + } + if len(results) == 0 { + c.JSON(404, gin.H{"status": "error", "message": "No tasks found for this client"}) + return + } + c.JSON(200, gin.H{"status": "success", "tasks": results}) +} diff --git a/api/client/task.go b/api/client/task.go new file mode 100644 index 0000000..9358916 --- /dev/null +++ b/api/client/task.go @@ -0,0 +1,35 @@ +package client + +import ( + "time" + + "github.com/gin-gonic/gin" + "github.com/komari-monitor/komari/database/clients" + "github.com/komari-monitor/komari/database/tasks" +) + +func TaskResult(c *gin.Context) { + token := c.Query("token") + clientId, _ := clients.GetClientUUIDByToken(token) + if clientId == "" { + c.JSON(400, gin.H{"status": "error", "message": "Invalid or missing token"}) + return + } + var req struct { + TaskId string `json:"task_id" binding:"required"` + Result string `json:"result" binding:"required"` + ExitCode int `json:"exit_code"` + FinishedAt time.Time `json:"finished_at" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(400, gin.H{"status": "error", "message": "Invalid request"}) + return + } + + if err := tasks.SaveTaskResult(req.TaskId, clientId, req.Result, req.ExitCode, req.FinishedAt); err != nil { + c.JSON(500, gin.H{"status": "error", "message": "Failed to update task result: " + err.Error()}) + return + } + + c.JSON(200, gin.H{"status": "success", "message": "Task result updated successfully"}) +} diff --git a/cmd/server.go b/cmd/server.go index 4e5f1cc..21a5458 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -86,10 +86,21 @@ var ServerCmd = &cobra.Command{ tokenAuthrized.POST("/uploadBasicInfo", client.UploadBasicInfo) tokenAuthrized.POST("/report", client.UploadReport) tokenAuthrized.GET("/terminal", client.EstablishConnection) + tokenAuthrized.POST("/task/result", client.TaskResult) } adminAuthrized := r.Group("/api/admin", api.AdminAuthMiddleware()) { + // tasks + taskGroup := adminAuthrized.Group("/task") + { + taskGroup.GET("/all", admin.GetTasks) + taskGroup.POST("/exec", admin.Exec) + taskGroup.GET("/:task_id", admin.GetTaskById) + taskGroup.GET("/:task_id/result", admin.GetTaskResultsByTaskId) + taskGroup.GET("/:task_id/result/:uuid", admin.GetSpecificTaskResult) + taskGroup.GET("/client/:uuid", admin.GetTasksByClientId) + } // settings adminAuthrized.GET("/settings", admin.GetSettings) adminAuthrized.POST("/settings", admin.EditSettings) diff --git a/database/dbcore/dbcore.go b/database/dbcore/dbcore.go index 17d9ffc..8e147f1 100644 --- a/database/dbcore/dbcore.go +++ b/database/dbcore/dbcore.go @@ -169,7 +169,13 @@ func GetDBInstance() *gorm.DB { if err != nil { log.Printf("Failed to create Session table, it may already exist: %v", err) } - + err = instance.AutoMigrate( + &models.Task{}, + &models.TaskResult{}, + ) + if err != nil { + log.Printf("Failed to create Task and TaskResult table, it may already exist: %v", err) + } // 如果存在旧表,执行数据迁移 if hasOldClientInfoTable { migrateClientData(instance) diff --git a/database/models/task.go b/database/models/task.go new file mode 100644 index 0000000..c8ec913 --- /dev/null +++ b/database/models/task.go @@ -0,0 +1,16 @@ +package models + +type Task struct { + TaskId string `json:"task_id" gorm:"type:varchar(36),primaryKey,uniqueIndex:idx_tasks_task_id"` + // Clients is a JSON array of client UUIDs + Clients string `json:"clients" gorm:"type:longtext"` + Command string `json:"command" gorm:"type:text"` +} +type TaskResult struct { + TaskId string `json:"task_id" gorm:"type:varchar(36)"` + Client string `json:"client" gorm:"type:varchar(36)"` + Result string `json:"result" gorm:"type:longtext"` + ExitCode *int `json:"exit_code" gorm:"type:int,"` + FinishedAt string `json:"finished_at" gorm:"type:timestamp"` + CreatedAt string `json:"created_at" gorm:"type:timestamp"` +} diff --git a/database/tasks/tasks.go b/database/tasks/tasks.go new file mode 100644 index 0000000..1bbcbb3 --- /dev/null +++ b/database/tasks/tasks.go @@ -0,0 +1,105 @@ +package tasks + +import ( + "encoding/json" + "time" + + "github.com/komari-monitor/komari/database/dbcore" + "github.com/komari-monitor/komari/database/models" +) + +func CreateTask(taskId string, clients []string, command string) error { + db := dbcore.GetDBInstance() + // Convert clients slice to JSON string + clientsJSON, err := json.Marshal(clients) + if err != nil { + return err + } + // Create a new task in the database + task := models.Task{ + TaskId: taskId, + Clients: string(clientsJSON), + Command: command, + } + if err := db.Create(&task).Error; err != nil { + return err + } + var taskResults []models.TaskResult + for _, client := range clients { + taskResults = append(taskResults, models.TaskResult{ + TaskId: taskId, + Client: client, + Result: "", + ExitCode: nil, + FinishedAt: "", + CreatedAt: time.Now().Format(time.RFC3339), + }) + } + if len(taskResults) > 0 { + return db.Create(&taskResults).Error + } + return nil +} +func GetTaskByTaskId(taskId string) (*models.Task, error) { + var task models.Task + if err := dbcore.GetDBInstance().Where("task_id = ?", taskId).First(&task).Error; err != nil { + return nil, err + } + return &task, nil +} +func GetTasksByClientId(clientId string) ([]models.Task, error) { + var tasks []models.Task + if err := dbcore.GetDBInstance().Where("clients LIKE ?", "%"+clientId+"%").Find(&tasks).Error; err != nil { + return nil, err + } + return tasks, nil +} + +func GetSpecificTaskResult(taskId, clientId string) (*models.TaskResult, error) { + var result models.TaskResult + if err := dbcore.GetDBInstance().Where("task_id = ? AND client = ?", taskId, clientId).First(&result).Error; err != nil { + return nil, err + } + return &result, nil +} + +func GetAllTasksResultByUUID(uuid string) ([]models.TaskResult, error) { + var results []models.TaskResult + if err := dbcore.GetDBInstance().Where("client = ?", uuid).Find(&results).Error; err != nil { + return nil, err + } + return results, nil +} +func GetAllTasks() ([]models.Task, error) { + var tasks []models.Task + if err := dbcore.GetDBInstance().Find(&tasks).Error; err != nil { + return nil, err + } + return tasks, nil +} + +func GetTaskResultsByTaskId(taskId string) ([]models.TaskResult, error) { + var results []models.TaskResult + if err := dbcore.GetDBInstance().Where("task_id = ?", taskId).Find(&results).Error; err != nil { + return nil, err + } + return results, nil +} +func DeleteTaskByTaskId(taskId string) error { + return dbcore.GetDBInstance().Where("task_id = ?", taskId).Delete(&models.Task{}).Error +} + +func SaveTaskResult(taskId, clientId, result string, exitCode int, timestamp time.Time) error { + taskResult := models.TaskResult{ + TaskId: taskId, + Client: clientId, + Result: result, + ExitCode: &exitCode, + FinishedAt: timestamp.Format(time.RFC3339), + } + return dbcore.GetDBInstance().Create(&taskResult).Error +} + +func ClearTaskResultsByTimeBefore(before time.Time) error { + return dbcore.GetDBInstance().Where("created_at < ?", before.Format(time.RFC3339)).Delete(&models.TaskResult{}).Error +}