refactor: 更新Client相关逻辑,迁移旧版ClientInfo数据并优化模型

This commit is contained in:
Akizon77
2025-05-29 17:29:02 +08:00
parent 554caf7a37
commit fefdcee60d
7 changed files with 156 additions and 64 deletions

View File

@@ -2,13 +2,9 @@ package admin
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/komari-monitor/komari/common"
"github.com/komari-monitor/komari/database/clients"
"github.com/komari-monitor/komari/database/dbcore"
"github.com/komari-monitor/komari/database/models"
"github.com/komari-monitor/komari/database/records"
)
@@ -40,30 +36,15 @@ func EditClient(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
db := dbcore.GetDBInstance()
var err error
if req["token"] != "" {
err = db.Model(&models.Client{}).Where("uuid = ?", uuid).
Updates(map[string]interface{}{"token": req["token"], "updated_at": time.Now()}).Error
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": err.Error()})
return
}
if uuid == "" {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "Invalid or missing UUID"})
return
}
allowed_fields := []string{"name", "remark", "public_remark", "weight", "price", "expired_at"}
updateFields := map[string]interface{}{
"updated_at": time.Now(),
}
for _, field := range allowed_fields {
if req[field] != nil {
updateFields[field] = req[field]
}
}
if len(updateFields) > 1 { // 大于1是因为至少包含了updated_at
if err := db.Model(&common.ClientInfo{}).Where("uuid = ?", uuid).Updates(updateFields).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": err.Error()})
return
}
err := clients.SaveClient(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"status": "success"})

View File

@@ -2,8 +2,8 @@ package admin
import (
"github.com/gin-gonic/gin"
"github.com/komari-monitor/komari/common"
"github.com/komari-monitor/komari/database/dbcore"
"github.com/komari-monitor/komari/database/models"
)
func OrderWeight(c *gin.Context) {
@@ -17,7 +17,7 @@ func OrderWeight(c *gin.Context) {
}
db := dbcore.GetDBInstance()
for uuid, weight := range req {
err := db.Model(&common.ClientInfo{}).Where("uuid = ?", uuid).Update("weight", weight).Error
err := db.Model(&models.Client{}).Where("uuid = ?", uuid).Update("weight", weight).Error
if err != nil {
c.JSON(500, gin.H{
"status": "error",

View File

@@ -19,6 +19,7 @@ func GetNodesInformation(c *gin.Context) {
clientList[i].IPv6 = ""
clientList[i].Remark = ""
clientList[i].Version = ""
clientList[i].Token = ""
}
c.JSON(200, gin.H{"status": "success", "data": clientList})

View File

@@ -29,6 +29,7 @@ type ClientConfig struct {
}
// ClientInfo stores static client information
// Deprecated: Use models.Client instead.
type ClientInfo struct {
UUID string `json:"uuid,omitempty" gorm:"type:varchar(36);primaryKey;foreignKey:ClientUUID;references:UUID;constraint:OnDelete:CASCADE"`
Name string `json:"name" gorm:"type:varchar(100);not null"`

View File

@@ -30,14 +30,10 @@ func DeleteClient(clientUuid string) error {
if err != nil {
return err
}
err = db.Delete(&common.ClientInfo{}, "uuid = ?", clientUuid).Error
if err != nil {
return err
}
return nil
}
// Decprecated: UpdateOrInsertBasicInfo is deprecated and will be removed in a future release. Use SaveClientInfo instead.
// Deprecated: UpdateOrInsertBasicInfo is deprecated and will be removed in a future release. Use SaveClientInfo instead.
func UpdateOrInsertBasicInfo(cbi common.ClientInfo) error {
db := dbcore.GetDBInstance()
updates := make(map[string]interface{})
@@ -78,10 +74,33 @@ func UpdateOrInsertBasicInfo(cbi common.ClientInfo) error {
updates["version"] = cbi.Version
updates["updated_at"] = time.Now()
err := db.Clauses(clause.OnConflict{
// 转换为更新Client表
client := models.Client{
UUID: cbi.UUID,
}
err := db.Model(&client).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "uuid"}},
DoUpdates: clause.Assignments(updates),
}).Create(&cbi).Error
}).Create(map[string]interface{}{
"uuid": cbi.UUID,
"name": cbi.Name,
"cpu_name": cbi.CpuName,
"arch": cbi.Arch,
"cpu_cores": cbi.CpuCores,
"os": cbi.OS,
"gpu_name": cbi.GpuName,
"ipv4": cbi.IPv4,
"ipv6": cbi.IPv6,
"region": cbi.Region,
"remark": cbi.Remark,
"mem_total": cbi.MemTotal,
"swap_total": cbi.SwapTotal,
"disk_total": cbi.DiskTotal,
"version": cbi.Version,
"updated_at": time.Now(),
}).Error
if err != nil {
return err
}
@@ -101,7 +120,7 @@ func SaveClientInfo(update map[string]interface{}) error {
update["updated_at"] = time.Now()
err := db.Model(&common.ClientInfo{}).Where("uuid = ?", clientUUID).Updates(update).Error
err := db.Model(&models.Client{}).Where("uuid = ?", clientUUID).Updates(update).Error
if err != nil {
return err
}
@@ -120,7 +139,7 @@ func UpdateClientConfig(config common.ClientConfig) error {
func EditClientName(clientUUID, clientName string) error {
db := dbcore.GetDBInstance()
err := db.Model(&models.Client{}).Where("uuid = ?", clientUUID).Update("client_name", clientName).Error
err := db.Model(&models.Client{}).Where("uuid = ?", clientUUID).Update("name", clientName).Error
if err != nil {
return err
}
@@ -160,6 +179,7 @@ func CreateClient() (clientUUID, token string, err error) {
client := models.Client{
UUID: clientUUID,
Token: token,
Name: "client_" + clientUUID[0:8],
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
@@ -168,14 +188,6 @@ func CreateClient() (clientUUID, token string, err error) {
if err != nil {
return "", "", err
}
clientInfo := common.ClientInfo{
UUID: clientUUID,
Name: "client_" + clientUUID[0:8],
}
err = db.Create(&clientInfo).Error
if err != nil {
return "", "", err
}
return clientUUID, token, nil
}
@@ -189,6 +201,7 @@ func CreateClientWithName(name string) (clientUUID, token string, err error) {
client := models.Client{
UUID: clientUUID,
Token: token,
Name: name,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
@@ -197,14 +210,6 @@ func CreateClientWithName(name string) (clientUUID, token string, err error) {
if err != nil {
return "", "", err
}
clientInfo := common.ClientInfo{
UUID: clientUUID,
Name: name,
}
err = db.Create(&clientInfo).Error
if err != nil {
return "", "", err
}
return clientUUID, token, nil
}
@@ -230,14 +235,14 @@ func GetClientByUUID(uuid string) (client models.Client, err error) {
}
// GetClientBasicInfo 获取指定 UUID 的客户端基本信息
func GetClientBasicInfo(uuid string) (client common.ClientInfo, err error) {
func GetClientBasicInfo(uuid string) (client models.Client, err error) {
db := dbcore.GetDBInstance()
err = db.Where("uuid = ?", uuid).First(&client).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return common.ClientInfo{}, fmt.Errorf("客户端不存在: %s", uuid)
return models.Client{}, fmt.Errorf("客户端不存在: %s", uuid)
}
return common.ClientInfo{}, err
return models.Client{}, err
}
return client, nil
}
@@ -252,7 +257,7 @@ func GetClientTokenByUUID(uuid string) (token string, err error) {
return client.Token, nil
}
func GetAllClientBasicInfo() (clients []common.ClientInfo, err error) {
func GetAllClientBasicInfo() (clients []models.Client, err error) {
db := dbcore.GetDBInstance()
err = db.Find(&clients).Error
if err != nil {
@@ -260,3 +265,24 @@ func GetAllClientBasicInfo() (clients []common.ClientInfo, err error) {
}
return clients, nil
}
func SaveClient(updates map[string]interface{}) error {
db := dbcore.GetDBInstance()
clientUUID, ok := updates["uuid"].(string)
if !ok || clientUUID == "" {
return fmt.Errorf("invalid client UUID")
}
// 确保更新的字段不为空
if len(updates) == 0 {
return fmt.Errorf("no fields to update")
}
updates["updated_at"] = time.Now()
err := db.Model(&models.Client{}).Where("uuid = ?", clientUUID).Updates(updates).Error
if err != nil {
return err
}
return nil
}

View File

@@ -15,6 +15,64 @@ import (
"gorm.io/gorm"
)
// migrateClientData 将旧版ClientInfo数据迁移到新版Client表
func migrateClientData(db *gorm.DB) {
log.Println("正在迁移旧版ClientInfo数据到新版Client表...")
// 读取所有ClientInfo记录
var clientInfos []common.ClientInfo
if err := db.Find(&clientInfos).Error; err != nil {
log.Printf("读取ClientInfo表失败: %v", err)
return
}
// 遍历每条记录并更新到Client表
for _, info := range clientInfos {
// 查找对应的Client记录
var client models.Client
if err := db.Where("uuid = ?", info.UUID).First(&client).Error; err != nil {
log.Printf("找不到UUID为%s的Client记录: %v", info.UUID, err)
continue
}
// 更新Client记录
client.Name = info.Name
client.CpuName = info.CpuName
client.Virtualization = info.Virtualization
client.Arch = info.Arch
client.CpuCores = info.CpuCores
client.OS = info.OS
client.GpuName = info.GpuName
client.IPv4 = info.IPv4
client.IPv6 = info.IPv6
client.Region = info.Region
client.Remark = info.Remark
client.PublicRemark = info.PublicRemark
client.MemTotal = info.MemTotal
client.SwapTotal = info.SwapTotal
client.DiskTotal = info.DiskTotal
client.Version = info.Version
client.Weight = info.Weight
client.Price = info.Price
client.BillingCycle = info.BillingCycle
client.ExpiredAt = info.ExpiredAt
// 保存更新后的Client记录
if err := db.Save(&client).Error; err != nil {
log.Printf("更新Client记录失败: %v", err)
continue
}
}
// 数据迁移完成后,备份并删除旧表
if err := db.Migrator().RenameTable("client_infos", "client_infos_backup"); err != nil {
log.Printf("备份ClientInfo表失败: %v", err)
return
}
log.Println("数据迁移完成旧表已备份为client_infos_backup")
}
var (
instance *gorm.DB
once sync.Once
@@ -92,14 +150,14 @@ func GetDBInstance() *gorm.DB {
default:
log.Fatalf("Unsupported database type: %s", flags.DatabaseType)
}
// 检查是否存在旧版ClientInfo表
hasOldClientInfoTable := instance.Migrator().HasTable(&common.ClientInfo{})
// 自动迁移模型
err = instance.AutoMigrate(
&models.User{},
&models.Client{},
// &models.Session{}, Error 1061 (42000): Duplicate key name 'idx_sessions_session'
&models.Record{},
&common.ClientInfo{},
&models.Config{},
)
if err != nil {
@@ -111,6 +169,11 @@ func GetDBInstance() *gorm.DB {
if err != nil {
log.Printf("Failed to create Session table, it may already exist: %v", err)
}
// 如果存在旧表,执行数据迁移
if hasOldClientInfoTable {
migrateClientData(instance)
}
})
return instance
}

View File

@@ -6,10 +6,30 @@ import (
// Client represents a registered client device
type Client struct {
UUID string `json:"uuid,omitempty" gorm:"type:varchar(36);primaryKey"`
Token string `json:"token,omitempty" gorm:"type:varchar(255);unique;not null"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
UUID string `json:"uuid,omitempty" gorm:"type:varchar(36);primaryKey"`
Token string `json:"token,omitempty" gorm:"type:varchar(255);unique;not null"`
Name string `json:"name" gorm:"type:varchar(100)"`
CpuName string `json:"cpu_name" gorm:"type:varchar(100)"`
Virtualization string `json:"virtualization" gorm:"type:varchar(50)"`
Arch string `json:"arch" gorm:"type:varchar(50)"`
CpuCores int `json:"cpu_cores" gorm:"type:int"`
OS string `json:"os" gorm:"type:varchar(100)"`
GpuName string `json:"gpu_name" gorm:"type:varchar(100)"`
IPv4 string `json:"ipv4,omitempty" gorm:"type:varchar(100)"`
IPv6 string `json:"ipv6,omitempty" gorm:"type:varchar(100)"`
Region string `json:"region" gorm:"type:varchar(100)"`
Remark string `json:"remark,omitempty" gorm:"type:longtext"`
PublicRemark string `json:"public_remark,omitempty" gorm:"type:longtext"`
MemTotal int64 `json:"mem_total" gorm:"type:bigint"`
SwapTotal int64 `json:"swap_total" gorm:"type:bigint"`
DiskTotal int64 `json:"disk_total" gorm:"type:bigint"`
Version string `json:"version,omitempty" gorm:"type:varchar(100)"`
Weight int `json:"weight" gorm:"type:int"`
Price float64 `json:"price"`
BillingCycle int `json:"billing_cycle"`
ExpiredAt time.Time `json:"expired_at" gorm:"type:timestamp"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// User represents an authenticated user