mirror of https://github.com/chaitin/PandaWiki.git
126 lines
3.3 KiB
Go
126 lines
3.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
|
|
"github.com/labstack/echo-contrib/session"
|
|
"github.com/labstack/echo/v4"
|
|
|
|
"github.com/chaitin/panda-wiki/domain"
|
|
"github.com/chaitin/panda-wiki/log"
|
|
"github.com/chaitin/panda-wiki/usecase"
|
|
)
|
|
|
|
type ShareAuthMiddleware struct {
|
|
logger *log.Logger
|
|
kbUsecase *usecase.KnowledgeBaseUsecase
|
|
}
|
|
|
|
func NewShareAuthMiddleware(logger *log.Logger, kbUsecase *usecase.KnowledgeBaseUsecase) *ShareAuthMiddleware {
|
|
return &ShareAuthMiddleware{
|
|
logger: logger.WithModule("middleware.share_auth"),
|
|
kbUsecase: kbUsecase,
|
|
}
|
|
}
|
|
|
|
func (h *ShareAuthMiddleware) CheckForbidden(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
kbID := c.Request().Header.Get("X-KB-ID")
|
|
if kbID == "" {
|
|
h.logger.Error("kb_id is empty")
|
|
return c.JSON(http.StatusBadRequest, domain.PWResponse{
|
|
Success: false,
|
|
Message: "kb_id is required",
|
|
})
|
|
}
|
|
|
|
kb, err := h.kbUsecase.GetKnowledgeBase(c.Request().Context(), kbID)
|
|
if err != nil {
|
|
h.logger.Error("get knowledge base failed", log.String("kb_id", kbID), log.Error(err))
|
|
return c.JSON(http.StatusInternalServerError, domain.PWResponse{
|
|
Success: false,
|
|
Message: "failed to get knowledge base detail",
|
|
})
|
|
}
|
|
|
|
if kb.AccessSettings.IsForbidden {
|
|
h.logger.Warn("access forbidden", log.String("kb_id", kbID))
|
|
return c.JSON(http.StatusForbidden, domain.PWResponse{
|
|
Success: false,
|
|
Message: "access is forbidden",
|
|
})
|
|
}
|
|
|
|
return next(c)
|
|
}
|
|
}
|
|
|
|
func (h *ShareAuthMiddleware) Authorize(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
kbID := c.Request().Header.Get("X-KB-ID")
|
|
if kbID == "" {
|
|
h.logger.Error("kb_id is empty")
|
|
return c.JSON(http.StatusUnauthorized, domain.PWResponse{
|
|
Success: false,
|
|
Message: "Unauthorized",
|
|
})
|
|
}
|
|
|
|
kb, err := h.kbUsecase.GetKnowledgeBase(c.Request().Context(), kbID)
|
|
if err != nil {
|
|
h.logger.Error("get knowledge base failed", log.String("kb_id", kbID), log.Error(err))
|
|
return c.JSON(http.StatusUnauthorized, domain.PWResponse{
|
|
Success: false,
|
|
Message: "Unauthorized",
|
|
})
|
|
}
|
|
|
|
if kb.AccessSettings.IsForbidden {
|
|
h.logger.Warn("access forbidden", log.String("kb_id", kbID))
|
|
return c.JSON(http.StatusForbidden, domain.PWResponse{
|
|
Success: false,
|
|
Message: "access is forbidden",
|
|
})
|
|
}
|
|
|
|
// 未开启认证
|
|
if !kb.AccessSettings.EnterpriseAuth.Enabled && !kb.AccessSettings.SimpleAuth.Enabled {
|
|
return next(c)
|
|
}
|
|
|
|
sess, err := session.Get(domain.SessionName, c)
|
|
if err != nil {
|
|
h.logger.Error("session get failed", log.Error(err))
|
|
return c.JSON(http.StatusUnauthorized, domain.PWResponse{
|
|
Success: false,
|
|
Message: "Unauthorized",
|
|
})
|
|
}
|
|
|
|
KbIDSess, ok := sess.Values["kb_id"].(string)
|
|
if !ok || kbID == "" || KbIDSess != kb.ID {
|
|
h.logger.Error("kb_id valid failed", log.Error(err))
|
|
return c.JSON(http.StatusUnauthorized, domain.PWResponse{
|
|
Success: false,
|
|
Message: "Unauthorized",
|
|
})
|
|
}
|
|
|
|
// 企业认证
|
|
if kb.AccessSettings.EnterpriseAuth.Enabled {
|
|
userId, ok := sess.Values["user_id"].(uint)
|
|
if !ok || userId == 0 {
|
|
h.logger.Error("session user_id get failed", log.Error(err))
|
|
return c.JSON(http.StatusUnauthorized, domain.PWResponse{
|
|
Success: false,
|
|
Message: "Unauthorized",
|
|
})
|
|
}
|
|
c.Set("user_id", userId)
|
|
return next(c)
|
|
}
|
|
|
|
return next(c)
|
|
}
|
|
}
|