fix(oidc): 修正请求 token 阶段的请求包

This commit is contained in:
imlonghao
2025-08-04 21:22:20 +08:00
parent 787e8bdd80
commit 53bd83af8b
5 changed files with 15 additions and 9 deletions

View File

@@ -19,9 +19,7 @@ func OAuth(c *gin.Context) {
return
}
redirectURI := utils.GetScheme(c) + "://" + c.Request.Host + "/api/oauth_callback"
authURL, state := oauth.CurrentProvider().GetAuthorizationURL(redirectURI)
authURL, state := oauth.CurrentProvider().GetAuthorizationURL(utils.GetCallbackURL(c))
c.SetCookie("oauth_state", state, 3600, "/", "", false, true)
@@ -44,7 +42,7 @@ func OAuthCallback(c *gin.Context) {
queries[key] = values[0]
}
}
oidcUser, err := oauth.CurrentProvider().OnCallback(c.Request.Context(), state, queries)
oidcUser, err := oauth.CurrentProvider().OnCallback(c.Request.Context(), state, queries, utils.GetCallbackURL(c))
if err != nil {
c.JSON(500, gin.H{"status": "error", "error": "Failed to get user info: " + err.Error()})
return

View File

@@ -23,3 +23,9 @@ func GetScheme(c *gin.Context) string {
}
return "http"
}
func GetCallbackURL(c *gin.Context) string {
scheme := GetScheme(c)
host := c.Request.Host
return scheme + "://" + host + "/api/oauth_callback"
}

View File

@@ -8,7 +8,7 @@ type IOidcProvider interface {
GetConfiguration() Configuration
// 获取授权URL和状态
GetAuthorizationURL(redirectURI string) (string, string)
OnCallback(ctx context.Context, state string, query map[string]string) (OidcCallback, error)
OnCallback(ctx context.Context, state string, query map[string]string, callbackURI string) (OidcCallback, error)
Init() error
Destroy() error
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/komari-monitor/komari/utils"
@@ -35,7 +36,7 @@ func (g *Generic) GetAuthorizationURL(redirectURI string) (string, string) {
g.stateCache.Set(state, true, cache.DefaultExpiration)
return authURL, state
}
func (g *Generic) OnCallback(ctx context.Context, state string, query map[string]string) (factory.OidcCallback, error) {
func (g *Generic) OnCallback(ctx context.Context, state string, query map[string]string, callbackURI string) (factory.OidcCallback, error) {
code := query["code"]
// 验证state防止CSRF攻击
@@ -59,12 +60,13 @@ func (g *Generic) OnCallback(ctx context.Context, state string, query map[string
"client_id": {g.Addition.ClientId},
"client_secret": {g.Addition.ClientSecret},
"code": {code},
"redirect_uri": {callbackURI},
"grant_type": {"authorization_code"},
}
req, _ := http.NewRequest("POST", g.Addition.TokenURL, nil)
req.URL.RawQuery = data.Encode()
req, _ := http.NewRequest("POST", g.Addition.TokenURL, strings.NewReader(data.Encode()))
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {

View File

@@ -36,7 +36,7 @@ func (g *Github) GetAuthorizationURL(_ string) (string, string) {
g.stateCache.Set(state, true, cache.NoExpiration)
return authURL, state
}
func (g *Github) OnCallback(ctx context.Context, state string, query map[string]string) (factory.OidcCallback, error) {
func (g *Github) OnCallback(ctx context.Context, state string, query map[string]string, _ string) (factory.OidcCallback, error) {
code := query["code"]
// 验证state防止CSRF攻击