Files
komari/api/oauth.go
Gil Schneider d0eefe2dd2 feat: 新增聚合登录(QQ)功能 (#185)
* feat:新增单点登录提供商“聚合登录”;
feat:新增集合登录的登录参数设置;

* feat:移除nsmao接口支持。

* 移除调试环境生成的二进制文件
2025-08-10 15:04:49 +08:00

111 lines
3.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/komari-monitor/komari/database/accounts"
"github.com/komari-monitor/komari/database/auditlog"
"github.com/komari-monitor/komari/database/config"
"github.com/komari-monitor/komari/utils"
"github.com/komari-monitor/komari/utils/oauth"
)
// OAuth handles GET /api/oauth and starts the OAuth flow for the currently
// configured provider.
//
// It asks the provider for an authorization URL (built with this server's
// callback URL) and the matching state value. The state is stored in the
// HttpOnly "oauth_state" cookie for one hour so that OAuthCallback can verify
// it. The browser is then redirected to the provider.
//
// Responds with 500 if the configuration cannot be loaded and with 403 if
// OAuth is disabled.
func OAuth(c *gin.Context) {
	cfg, err := config.Get()
	if err != nil {
		// Do not fall through with an unusable config: that would either
		// misreport the problem as "not enabled" or dereference a bad value.
		c.JSON(500, gin.H{"status": "error", "error": "Failed to load configuration"})
		return
	}
	if !cfg.OAuthEnabled {
		c.JSON(403, gin.H{"status": "error", "error": "OAuth is not enabled"})
		return
	}
	authURL, state := oauth.CurrentProvider().GetAuthorizationURL(utils.GetCallbackURL(c))
	// maxAge is in seconds (3600 = 1 hour); HttpOnly keeps scripts from reading it.
	c.SetCookie("oauth_state", state, 3600, "/", "", false, true)
	c.Redirect(302, authURL)
}
// OAuthCallback handles GET /api/oauth_callback, the URL the OAuth provider
// redirects back to after the user has authorized the application.
//
// Flow:
//  1. Consume the one-shot "oauth_state" cookie and validate it against the
//     "state" query parameter (CSRF protection).
//  2. Hand all query parameters to the current provider to obtain the
//     external user identity. The SSO ID is "<provider name>_<user id>".
//  3. If the "binding_external_account" cookie is set, bind that SSO ID to
//     the logged-in account whose UUID matches the cookie, then redirect to
//     /manage.
//  4. Otherwise, log in the account already bound to that SSO ID, set the
//     "session_token" cookie and redirect to /admin.
//
// Error responses:
//   - 500: config load failure, provider callback failure, binding failure,
//     or session creation failure.
//   - 403: OAuth is disabled.
//   - 400: invalid state or empty user ID.
//   - 401: no account is bound to the external identity.
func OAuthCallback(c *gin.Context) {
	// Consume the state cookie up front so it cannot be reused, whatever the
	// outcome of this request.
	state, _ := c.Cookie("oauth_state")
	c.SetCookie("oauth_state", "", -1, "/", "", false, true)

	// Apply the same switch as the OAuth entry point. Without this check a
	// client could still finish an SSO login by calling the callback
	// directly after OAuth has been disabled.
	cfg, err := config.Get()
	if err != nil {
		c.JSON(500, gin.H{"status": "error", "error": "Failed to load configuration"})
		return
	}
	if !cfg.OAuthEnabled {
		c.JSON(403, gin.H{"status": "error", "error": "OAuth is not enabled"})
		return
	}

	// Resolve the provider once, so that state validation, the callback
	// exchange and the SSO ID prefix all refer to the same provider.
	provider := oauth.CurrentProvider()
	providerName := provider.GetName()

	// QQ login is relayed through the QQ aggregated-login platform, so the
	// returned state may not match ours. For QQ we only require that our
	// state cookie exists. Every other provider must echo the exact state.
	if state == "" || (providerName != "qq" && state != c.Query("state")) {
		c.JSON(400, gin.H{"status": "error", "error": "Invalid state"})
		return
	}

	// Flatten the query string, keeping only the first value of each key.
	rawQuery := c.Request.URL.Query()
	queries := make(map[string]string, len(rawQuery))
	for key, values := range rawQuery {
		if len(values) > 0 {
			queries[key] = values[0]
		}
	}

	oidcUser, err := provider.OnCallback(c.Request.Context(), state, queries, utils.GetCallbackURL(c))
	if err != nil {
		c.JSON(500, gin.H{"status": "error", "error": "Failed to get user info: " + err.Error()})
		return
	}
	// An empty ID would produce a generic "<provider>_" SSO ID that any other
	// empty-ID response could match, so refuse it outright.
	if oidcUser.UserId == "" {
		c.JSON(400, gin.H{"status": "error", "error": "Failed to get user info: empty user id"})
		return
	}

	// The provider-prefixed external ID is what accounts are bound to.
	ssoID := fmt.Sprintf("%s_%s", providerName, oidcUser.UserId)

	// A non-empty binding cookie means this flow links an external account
	// to an existing user rather than logging in. The cookie is cleared
	// either way.
	bindingUUID, _ := c.Cookie("binding_external_account")
	c.SetCookie("binding_external_account", "", -1, "/", "", false, true)
	if bindingUUID != "" {
		// Only bind when the current session belongs to the account named
		// in the binding cookie.
		sessionToken, _ := c.Cookie("session_token")
		user, err := accounts.GetUserBySession(sessionToken)
		if err != nil || user.UUID != bindingUUID {
			c.JSON(500, gin.H{"status": "error", "message": "Binding failed"})
			return
		}
		if err := accounts.BindingExternalAccount(user.UUID, ssoID); err != nil {
			c.JSON(500, gin.H{"status": "error", "message": "Binding failed"})
			return
		}
		auditlog.Log(c.ClientIP(), user.UUID, fmt.Sprintf("bound external account (OAuth),sso_id: %s", ssoID), "login")
		c.Redirect(302, "/manage")
		return
	}

	// Login: the external identity must already be bound to an account.
	user, err := accounts.GetUserBySSO(ssoID)
	if err != nil {
		c.JSON(401, gin.H{
			"status":  "error",
			"message": "please log in and bind your external account first.",
		})
		return
	}

	// 30 days in seconds. This is passed to CreateSession (presumably the
	// session lifetime — confirm) and used as the cookie maxAge.
	const sessionMaxAge = 2592000
	session, err := accounts.CreateSession(user.UUID, sessionMaxAge, c.Request.UserAgent(), c.ClientIP(), "oauth")
	if err != nil {
		c.JSON(500, gin.H{"status": "error", "message": err.Error()})
		return
	}
	c.SetCookie("session_token", session, sessionMaxAge, "/", "", false, true)
	auditlog.Log(c.ClientIP(), user.UUID, "logged in (OAuth)", "login")
	c.Redirect(302, "/admin")
}