PandaWiki/backend/middleware/jwt.go

278 lines
7.4 KiB
Go

package middleware
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v5"
echoMiddleware "github.com/labstack/echo-jwt/v4"
"github.com/labstack/echo/v4"
"github.com/chaitin/panda-wiki/config"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/pg"
)
// JWTMiddleware bundles the echo middlewares used to authenticate requests
// (JWT or opaque API token) and to authorize them by user role,
// knowledge-base permission and license edition.
type JWTMiddleware struct {
	config         *config.Config
	// jwtMiddleware is the echo-jwt middleware built in NewJWTMiddleware.
	jwtMiddleware  echo.MiddlewareFunc
	logger         *log.Logger
	// userAccessRepo backs role / KB-permission checks and access-time updates.
	userAccessRepo *pg.UserAccessRepository
	// apiTokenRepo resolves opaque (non-JWT) bearer tokens. It may be nil, in
	// which case every API-token request is rejected with 401.
	apiTokenRepo   *pg.APITokenRepo
}
// NewJWTMiddleware builds a JWTMiddleware whose JWT verification uses
// config.Auth.JWT.Secret as the signing key. Requests whose JWT fails
// verification are logged and answered with 401 and a generic
// "Unauthorized" body.
func NewJWTMiddleware(config *config.Config, logger *log.Logger, userAccessRepo *pg.UserAccessRepository, apiTokenRepo *pg.APITokenRepo) *JWTMiddleware {
	// Scope the logger once so the echo-jwt error handler and the methods on
	// JWTMiddleware all log under the same module name.
	moduleLogger := logger.WithModule("middleware.jwt")
	jwtMiddleware := echoMiddleware.WithConfig(echoMiddleware.Config{
		SigningKey: []byte(config.Auth.JWT.Secret),
		ErrorHandler: func(c echo.Context, err error) error {
			moduleLogger.Error("jwt auth failed", log.Error(err))
			return c.JSON(http.StatusUnauthorized, domain.PWResponse{
				Success: false,
				Message: "Unauthorized",
			})
		},
	})
	return &JWTMiddleware{
		config:         config,
		jwtMiddleware:  jwtMiddleware,
		logger:         moduleLogger,
		userAccessRepo: userAccessRepo,
		apiTokenRepo:   apiTokenRepo,
	}
}
// Authorize authenticates the request before calling next.
//
// A "Bearer <token>" Authorization header whose token contains no "." is
// treated as an opaque API token and handled by validateAPIToken. Every other
// request goes through the echo-jwt middleware. If the verified JWT carries a
// string "id" claim, a non-token CtxAuthInfo for that user is stored in the
// request context and the user's access time is updated before next runs.
func (m *JWTMiddleware) Authorize(next echo.HandlerFunc) echo.HandlerFunc {
	// Wrap next with the JWT middleware once per route rather than once per
	// request: neither m.jwtMiddleware nor next changes between requests.
	jwtHandler := m.jwtMiddleware(func(c echo.Context) error {
		if userID, ok := m.MustGetUserID(c); ok {
			ctx := context.WithValue(c.Request().Context(), domain.CtxAuthInfoKey, &domain.CtxAuthInfo{
				IsToken:    false,
				Permission: consts.UserKBPermissionNull,
				UserId:     userID,
			})
			c.SetRequest(c.Request().WithContext(ctx))
			m.userAccessRepo.UpdateAccessTime(userID)
		}
		return next(c)
	})
	return func(c echo.Context) error {
		authHeader := c.Request().Header.Get("Authorization")
		if strings.HasPrefix(authHeader, "Bearer ") {
			token := strings.TrimPrefix(authHeader, "Bearer ")
			// JWTs always contain "." separators; opaque API tokens do not.
			if !strings.Contains(token, ".") {
				return m.validateAPIToken(c, token, next)
			}
		}
		return jwtHandler(c)
	}
}
// validateAPIToken authenticates the request with an opaque API token.
//
// The token is resolved through apiTokenRepo. On success the request context
// receives a CtxAuthInfo marked IsToken, carrying the token's permission,
// user id and knowledge-base id, and next is called. If the repository is not
// configured, the lookup fails, or no token matches, the request gets 401.
func (m *JWTMiddleware) validateAPIToken(c echo.Context, token string, next echo.HandlerFunc) error {
	if m.apiTokenRepo == nil {
		m.logger.Debug("API token repository not available")
		return c.JSON(http.StatusUnauthorized, domain.PWResponse{
			Success: false,
			Message: "Unauthorized",
		})
	}
	apiToken, err := m.apiTokenRepo.GetByTokenWithCache(c.Request().Context(), token)
	if err != nil || apiToken == nil {
		if err != nil {
			m.logger.Error("failed to get API token", log.Error(err))
		} else {
			// No error but no token: the token is unknown. That is a client
			// problem, and there is no error value worth logging.
			m.logger.Info("API token not found")
		}
		return c.JSON(http.StatusUnauthorized, domain.PWResponse{
			Success: false,
			Message: "Unauthorized",
		})
	}
	ctx := context.WithValue(c.Request().Context(), domain.CtxAuthInfoKey, &domain.CtxAuthInfo{
		IsToken:    true,
		Permission: apiToken.Permission,
		UserId:     apiToken.UserID,
		KBId:       apiToken.KbId,
	})
	c.SetRequest(c.Request().WithContext(ctx))
	return next(c)
}
// ValidateUserRole returns a middleware that requires the authenticated caller
// to hold role.
//
// Requests with no auth info in the context get 401. API-token callers are
// treated as ordinary users: they get 401 when role is consts.UserRoleAdmin
// and pass otherwise. JWT users are checked with
// userAccessRepo.ValidateRole; a lookup error or a failed check gets 403.
func (m *JWTMiddleware) ValidateUserRole(role consts.UserRole) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			authInfo := domain.GetAuthInfoFromCtx(c.Request().Context())
			if authInfo == nil {
				return c.JSON(http.StatusUnauthorized, domain.PWResponse{
					Success: false,
					Message: "Unauthorized",
				})
			}
			if authInfo.IsToken {
				// API tokens act as ordinary users and never carry admin rights.
				// NOTE(review): 403 would be the more accurate status; kept as 401
				// so existing clients are unaffected.
				if role == consts.UserRoleAdmin {
					return c.JSON(http.StatusUnauthorized, domain.PWResponse{
						Success: false,
						Message: "token not support admin role",
					})
				}
				return next(c)
			}
			valid, err := m.userAccessRepo.ValidateRole(authInfo.UserId, role)
			if err != nil || !valid {
				if err != nil {
					// Log the actual error so repository failures can be told apart
					// from plain permission denials.
					m.logger.Error("ValidateRole check failed", log.String("user_id", authInfo.UserId), log.Error(err))
				} else {
					m.logger.Info("ValidateRole check", log.String("user_id", authInfo.UserId), log.Any("valid", valid))
				}
				return c.JSON(http.StatusForbidden, domain.PWResponse{
					Success: false,
					Message: "StatusForbidden ValidateRole",
				})
			}
			return next(c)
		}
	}
}
// ValidateKBUserPerm returns a middleware that requires the caller to hold
// perm on the knowledge base targeted by the request (see GetKbID).
//
// Requests with no auth info in the context get 401. If the KB id cannot be
// extracted because reading the request body failed, the request gets 400.
// API-token callers must be bound to that same KB and hold either
// consts.UserKBPermissionFullControl or exactly perm. JWT users are checked
// with userAccessRepo.ValidateKBPerm. Any failed check gets 403.
func (m *JWTMiddleware) ValidateKBUserPerm(perm consts.UserKBPermission) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			authInfo := domain.GetAuthInfoFromCtx(c.Request().Context())
			if authInfo == nil {
				return c.JSON(http.StatusUnauthorized, domain.PWResponse{
					Success: false,
					Message: "Unauthorized",
				})
			}
			kbID, err := GetKbID(c)
			if err != nil {
				// Reading the body failed. The KB id is unknown and the body may be
				// partially consumed, so reject instead of checking against "".
				m.logger.Error("ValidateKBUserPerm GetKbID failed", log.Error(err))
				return c.JSON(http.StatusBadRequest, domain.PWResponse{
					Success: false,
					Message: "Bad Request",
				})
			}
			if authInfo.IsToken {
				// An API token is scoped to exactly one knowledge base.
				if authInfo.KBId != kbID {
					m.logger.Info("ValidateKBUserPerm ValidateTokenKBPerm kbId", log.String("token_kb_id", authInfo.KBId), log.String("kb_id", kbID))
					return c.JSON(http.StatusForbidden, domain.PWResponse{
						Success: false,
						Message: "Unauthorized ValidateTokenKBPerm kbId",
					})
				}
				if authInfo.Permission != consts.UserKBPermissionFullControl && authInfo.Permission != perm {
					return c.JSON(http.StatusForbidden, domain.PWResponse{
						Success: false,
						Message: "Unauthorized ValidateTokenKBPerm",
					})
				}
				return next(c)
			}
			// Regular (JWT-authenticated) user request.
			valid, err := m.userAccessRepo.ValidateKBPerm(kbID, authInfo.UserId, perm)
			if err != nil || !valid {
				if err != nil {
					m.logger.Error("ValidateKBUserPerm ValidateKBPerm failed", log.Error(err))
				} else {
					m.logger.Info("ValidateKBUserPerm ValidateKBPerm failed", log.String("kb_id", kbID), log.String("user_id", authInfo.UserId))
				}
				return c.JSON(http.StatusForbidden, domain.PWResponse{
					Success: false,
					Message: "Unauthorized ValidateKBPerm",
				})
			}
			return next(c)
		}
	}
}
// ValidateLicenseEdition returns a middleware that lets a request through only
// when the echo context value "edition" is a consts.LicenseEdition that is not
// lower than needEdition. A missing or wrongly typed value counts as
// insufficient. Rejected requests get 403.
func (m *JWTMiddleware) ValidateLicenseEdition(needEdition consts.LicenseEdition) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			current, found := c.Get("edition").(consts.LicenseEdition)
			if found && current >= needEdition {
				return next(c)
			}
			return c.JSON(http.StatusForbidden, domain.PWResponse{
				Success: false,
				Message: "Unauthorized ValidateLicenseEdition",
			})
		}
	}
}
// MustGetUserID reads the user id from the "id" claim of the *jwt.Token stored
// under the "user" key of the echo context.
//
// Despite the "Must" prefix it never panics. The boolean is false when there
// is no token, the claims are not jwt.MapClaims, or "id" is not a string.
func (m *JWTMiddleware) MustGetUserID(c echo.Context) (string, bool) {
	token, isToken := c.Get("user").(*jwt.Token)
	if isToken && token != nil {
		if mapClaims, isMap := token.Claims.(jwt.MapClaims); isMap {
			userID, isString := mapClaims["id"].(string)
			return userID, isString
		}
	}
	return "", false
}
// GetKbID extracts the knowledge-base id that a request targets.
//
// For GET and DELETE it reads the query string: "id" first when the URL path
// contains "knowledge_base", then "kb_id". For POST, PATCH and PUT it reads
// the whole body, puts it back on the request so later handlers can still
// bind it, and decodes it as a JSON object using the same "id" / "kb_id"
// precedence. Only non-empty string values are accepted. A body that does not
// decode as a JSON object yields "". Other methods yield "". The only error
// returned is a failure to read the request body.
func GetKbID(c echo.Context) (string, error) {
	req := c.Request()
	switch req.Method {
	case http.MethodGet, http.MethodDelete:
		if strings.Contains(req.URL.Path, "knowledge_base") {
			if id := c.QueryParam("id"); id != "" {
				return id, nil
			}
		}
		return c.QueryParam("kb_id"), nil
	case http.MethodPost, http.MethodPatch, http.MethodPut:
		raw, err := io.ReadAll(req.Body)
		if err != nil {
			return "", err
		}
		// Restore the consumed body for downstream handlers.
		req.Body = io.NopCloser(bytes.NewBuffer(raw))

		var payload map[string]any
		if json.Unmarshal(raw, &payload) != nil {
			return "", nil
		}
		if strings.Contains(req.URL.Path, "knowledge_base") {
			if id, _ := payload["id"].(string); id != "" {
				return id, nil
			}
		}
		kbID, _ := payload["kb_id"].(string)
		return kbID, nil
	}
	return "", nil
}