PandaWiki/backend/usecase/node.go

690 lines
19 KiB
Go

package usecase
import (
"context"
"errors"
"fmt"
"slices"
"strings"
"github.com/gomarkdown/markdown"
"github.com/gomarkdown/markdown/html"
"github.com/gomarkdown/markdown/parser"
"github.com/microcosm-cc/bluemonday"
"github.com/samber/lo"
"gorm.io/gorm"
v1 "github.com/chaitin/panda-wiki/api/node/v1"
shareV1 "github.com/chaitin/panda-wiki/api/share/v1"
"github.com/chaitin/panda-wiki/consts"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/mq"
"github.com/chaitin/panda-wiki/repo/pg"
"github.com/chaitin/panda-wiki/store/rag"
"github.com/chaitin/panda-wiki/store/s3"
"github.com/chaitin/panda-wiki/utils"
)
// NodeUsecase groups the repositories, stores and helper usecases that the
// node operations in this file are built on. Instances are created with
// NewNodeUsecase.
type NodeUsecase struct {
// nodeRepo provides node, node-release and node-group persistence calls.
nodeRepo *pg.NodeRepository
// appRepo is stored by the constructor; no method visible here reads it.
appRepo *pg.AppRepository
// ragRepo submits asynchronous node-release vector requests.
ragRepo *mq.RAGRepository
// kbRepo lists knowledge bases and looks up their latest release.
kbRepo *pg.KnowledgeBaseRepository
// modelRepo is stored by the constructor; no method visible here reads it.
modelRepo *pg.ModelRepository
// userRepo supplies the user-ID-to-account map used for share details.
userRepo *pg.UserRepository
// authRepo resolves the auth groups (including parents) of an auth ID.
authRepo *pg.AuthRepo
// llmUsecase generates node summaries.
llmUsecase *LLMUsecase
// logger is scoped to the "usecase.node" module by the constructor.
logger *log.Logger
// s3Client is stored by the constructor; no method visible here reads it.
s3Client *s3.MinioClient
// rAGService is queried for document status via ListDocuments.
rAGService rag.RAGService
// modelUsecase supplies the chat model used by SummaryNode.
modelUsecase *ModelUsecase
}
// NewNodeUsecase builds a NodeUsecase from the supplied dependencies. Every
// dependency is stored as given except the logger, which is first scoped to
// the "usecase.node" module.
func NewNodeUsecase(
	nodeRepo *pg.NodeRepository,
	appRepo *pg.AppRepository,
	ragRepo *mq.RAGRepository,
	userRepo *pg.UserRepository,
	kbRepo *pg.KnowledgeBaseRepository,
	llmUsecase *LLMUsecase,
	ragService rag.RAGService,
	logger *log.Logger,
	s3Client *s3.MinioClient,
	modelRepo *pg.ModelRepository,
	authRepo *pg.AuthRepo,
	modelUsecase *ModelUsecase,
) *NodeUsecase {
	moduleLogger := logger.WithModule("usecase.node")

	// Fields are listed in struct declaration order.
	uc := &NodeUsecase{
		nodeRepo:     nodeRepo,
		appRepo:      appRepo,
		ragRepo:      ragRepo,
		kbRepo:       kbRepo,
		modelRepo:    modelRepo,
		userRepo:     userRepo,
		authRepo:     authRepo,
		llmUsecase:   llmUsecase,
		logger:       moduleLogger,
		s3Client:     s3Client,
		rAGService:   ragService,
		modelUsecase: modelUsecase,
	}
	return uc
}
// ragSyncChunkSize is the maximum number of document IDs that
// SyncRagNodeStatus passes to a single RAG ListDocuments call.
const ragSyncChunkSize = 100
// Create stores a new node on behalf of userId through the node repository
// and returns the ID the repository reports. On failure the returned ID is
// always the empty string.
func (u *NodeUsecase) Create(ctx context.Context, req *domain.CreateNodeReq, userId string) (string, error) {
	createdID, createErr := u.nodeRepo.Create(ctx, req, userId)
	if createErr != nil {
		return "", createErr
	}
	return createdID, nil
}
// GetList returns the nodes matching req. Each returned node whose ID appears
// in the knowledge base's release publisher map gets its PublisherId set from
// that map; other nodes are returned unchanged.
func (u *NodeUsecase) GetList(ctx context.Context, req *domain.GetNodeListReq) ([]*domain.NodeListItemResp, error) {
	items, err := u.nodeRepo.GetList(ctx, req)
	if err != nil {
		return nil, err
	}
	// Nothing to annotate, so skip the publisher lookup entirely.
	if len(items) == 0 {
		return items, nil
	}

	publishers, err := u.nodeRepo.GetNodeReleasePublisherMap(ctx, req.KBID)
	if err != nil {
		return nil, err
	}
	for idx := range items {
		publisher, found := publishers[items[idx].ID]
		if !found {
			continue
		}
		items[idx].PublisherId = publisher
	}
	return items, nil
}
// GetNodeByKBID loads the node id within knowledge base kbId and, when a
// latest release exists, copies its publisher ID and account onto the result.
//
// The content is converted from Markdown to sanitized HTML only when all of
// the following hold: the node's content type is not Markdown, format is not
// "raw", and the content does not already look like HTML.
func (u *NodeUsecase) GetNodeByKBID(ctx context.Context, id, kbId, format string) (*v1.NodeDetailResp, error) {
	detail, err := u.nodeRepo.GetByID(ctx, id, kbId)
	if err != nil {
		return nil, err
	}

	release, err := u.nodeRepo.GetLatestNodeReleaseWithPublishAccount(ctx, detail.ID)
	if err != nil {
		return nil, err
	}
	if release != nil {
		detail.PublisherId = release.PublisherId
		detail.PublisherAccount = release.PublisherAccount
	}

	// Any of these conditions means the stored content is returned as is.
	keepContent := detail.Meta.ContentType == domain.ContentTypeMD ||
		format == "raw" ||
		utils.IsLikelyHTML(detail.Content)
	if keepContent {
		return detail, nil
	}

	detail.Content = u.convertMDToHTML(detail.Content)
	return detail, nil
}
// NodeAction performs the batch action named in req.Action. Only "delete" is
// handled: the nodes in req.IDs are deleted from the knowledge base and one
// "delete" vector request is submitted for every document ID the repository
// returns. Any other action is a no-op that returns nil.
func (u *NodeUsecase) NodeAction(ctx context.Context, req *domain.NodeActionReq) error {
	if req.Action != "delete" {
		return nil
	}

	removedDocIDs, err := u.nodeRepo.Delete(ctx, req.KBID, req.IDs)
	if err != nil {
		return err
	}

	requests := make([]*domain.NodeReleaseVectorRequest, 0, len(removedDocIDs))
	for _, removedDocID := range removedDocIDs {
		request := &domain.NodeReleaseVectorRequest{
			KBID:   req.KBID,
			DocID:  removedDocID,
			Action: "delete",
		}
		requests = append(requests, request)
	}

	if err := u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, requests); err != nil {
		return err
	}
	return nil
}
// Update writes the node changes described by req on behalf of userId and
// returns the repository's error, if any.
func (u *NodeUsecase) Update(ctx context.Context, req *domain.UpdateNodeReq, userId string) error {
	if updateErr := u.nodeRepo.UpdateNodeContent(ctx, req, userId); updateErr != nil {
		return updateErr
	}
	return nil
}
// ValidateNodePerm reports whether the auth identity authId may visit the
// released node nodeId in knowledge base kbID. It returns nil when access is
// allowed, otherwise a pointer to one of the domain error codes:
//
//   - ErrCodeNotFound when the release detail cannot be loaded;
//   - ErrCodePermissionDenied when the node is closed, or is partially open
//     and none of the identity's auth groups grants the visitable permission;
//   - ErrCodeInternalError when the group lookup fails or the node carries an
//     unrecognised visitable setting.
func (u *NodeUsecase) ValidateNodePerm(ctx context.Context, kbID, nodeId string, authId uint) *domain.PWResponseErrCode {
	release, err := u.nodeRepo.GetNodeReleaseDetailByKBIDAndID(ctx, kbID, nodeId)
	if err != nil {
		return &domain.ErrCodeNotFound
	}

	switch release.Permissions.Visitable {
	case consts.NodeAccessPermOpen:
		return nil
	case consts.NodeAccessPermClosed:
		return &domain.ErrCodePermissionDenied
	case consts.NodeAccessPermPartial:
		// Decided by the group membership check below.
	default:
		return &domain.ErrCodeInternalError
	}

	// Node IDs that the identity's auth groups (including parents) may visit.
	visitableNodeIDs, err := u.GetNodeIdsByAuthId(ctx, authId, consts.NodePermNameVisitable)
	if err != nil {
		return &domain.ErrCodeInternalError
	}
	if slices.Contains(visitableNodeIDs, nodeId) {
		return nil
	}

	u.logger.Error("ValidateNodePerm failed", log.Any("node_group_ids", visitableNodeIDs), log.Any("node_id", nodeId))
	return &domain.ErrCodePermissionDenied
}
// GetNodeReleaseDetailByKBIDAndID loads the released node nodeId of knowledge
// base kbID and fills in the creator, editor and publisher accounts for every
// ID that is present in the user account map.
//
// The content is converted from Markdown to sanitized HTML only when the
// node's content type is not Markdown, format is not "raw", and the content
// does not already look like HTML.
func (u *NodeUsecase) GetNodeReleaseDetailByKBIDAndID(ctx context.Context, kbID, nodeId, format string) (*shareV1.ShareNodeDetailResp, error) {
	detail, err := u.nodeRepo.GetNodeReleaseDetailByKBIDAndID(ctx, kbID, nodeId)
	if err != nil {
		return nil, err
	}

	accounts, err := u.userRepo.GetUsersAccountMap(ctx)
	if err != nil {
		return nil, err
	}
	if creator, found := accounts[detail.CreatorId]; found {
		detail.CreatorAccount = creator
	}
	if editor, found := accounts[detail.EditorId]; found {
		detail.EditorAccount = editor
	}
	if publisher, found := accounts[detail.PublisherId]; found {
		detail.PublisherAccount = publisher
	}

	// Cases are evaluated top to bottom; the first match leaves the content
	// untouched, and only the default branch converts it.
	switch {
	case detail.Meta.ContentType == domain.ContentTypeMD:
	case format == "raw":
	case utils.IsLikelyHTML(detail.Content):
	default:
		detail.Content = u.convertMDToHTML(detail.Content)
	}
	return detail, nil
}
// MoveNode forwards the move described by req (node ID, parent ID, previous
// and next sibling IDs, knowledge base ID) to the node repository and returns
// its error.
func (u *NodeUsecase) MoveNode(ctx context.Context, req *domain.MoveNodeReq) error {
	moveErr := u.nodeRepo.MoveNodeBetween(
		ctx,
		req.ID,
		req.ParentID,
		req.PrevID,
		req.NextID,
		req.KbID,
	)
	return moveErr
}
// SummaryNode summarises the nodes listed in req.IDs.
//
// The chat model is resolved first; if that lookup reports (directly or via a
// wrapped error) gorm.ErrRecordNotFound, domain.ErrModelNotConfigured is
// returned instead.
//
// With exactly one ID the node is loaded and summarised synchronously, and the
// summary text is returned. With any other number of IDs one "summary" vector
// request per ID is submitted asynchronously and the returned string is empty.
func (u *NodeUsecase) SummaryNode(ctx context.Context, req *domain.NodeSummaryReq) (string, error) {
	model, err := u.modelUsecase.GetChatModel(ctx)
	if err != nil {
		// errors.Is also matches when the sentinel has been wrapped.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return "", domain.ErrModelNotConfigured
		}
		return "", err
	}

	if len(req.IDs) == 1 {
		node, err := u.nodeRepo.GetNodeByID(ctx, req.IDs[0])
		if err != nil {
			return "", fmt.Errorf("get latest node release failed: %w", err)
		}
		summary, err := u.llmUsecase.SummaryNode(ctx, model, node.Name, node.Content)
		if err != nil {
			return "", fmt.Errorf("summary node failed: %w", err)
		}
		return summary, nil
	}

	// Batch case: hand the summaries off to the asynchronous pipeline.
	nodeVectorContentRequests := make([]*domain.NodeReleaseVectorRequest, 0, len(req.IDs))
	for _, id := range req.IDs {
		nodeVectorContentRequests = append(nodeVectorContentRequests, &domain.NodeReleaseVectorRequest{
			KBID:   req.KBID,
			NodeID: id,
			Action: "summary",
		})
	}
	if err := u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, nodeVectorContentRequests); err != nil {
		return "", err
	}
	return "", nil
}
// GetRecommendNodeList returns the recommended nodes named in req.NodeIDs,
// looked up in the latest release of knowledge base req.KBID and ordered as in
// req.NodeIDs. IDs the repository does not return are dropped. Folder nodes
// additionally get RecommendNodes filled from the repository's
// parent-ID-to-nodes map.
//
// It returns (nil, nil) when the knowledge base has no release yet or when the
// repository returns no nodes.
func (u *NodeUsecase) GetRecommendNodeList(ctx context.Context, req *domain.GetRecommendNodeListReq) ([]*domain.RecommendNodeListResp, error) {
	latestRelease, err := u.kbRepo.GetLatestRelease(ctx, req.KBID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}

	found, err := u.nodeRepo.GetRecommendNodeListByIDs(ctx, req.KBID, latestRelease.ID, req.NodeIDs)
	if err != nil {
		return nil, err
	}
	if len(found) == 0 {
		return nil, nil
	}

	// Index the results by ID; a later duplicate overwrites an earlier one.
	byID := make(map[string]*domain.RecommendNodeListResp, len(found))
	for _, item := range found {
		byID[item.ID] = item
	}

	// Rebuild the list in request order, collecting folder IDs on the way.
	ordered := make([]*domain.RecommendNodeListResp, 0)
	var folderIDs []string
	for _, wantedID := range req.NodeIDs {
		item, present := byID[wantedID]
		if !present {
			continue
		}
		ordered = append(ordered, item)
		if item.Type == domain.NodeTypeFolder {
			folderIDs = append(folderIDs, item.ID)
		}
	}
	if len(folderIDs) == 0 {
		return ordered, nil
	}

	nodesByParent, err := u.nodeRepo.GetRecommendNodeListByParentIDs(ctx, req.KBID, latestRelease.ID, folderIDs)
	if err != nil {
		return nil, err
	}
	for _, item := range ordered {
		if nested, present := nodesByParent[item.ID]; present {
			item.RecommendNodes = nested
		}
	}
	return ordered, nil
}
// BatchMoveNode forwards the batch move request to the node repository and
// returns its error.
func (u *NodeUsecase) BatchMoveNode(ctx context.Context, req *domain.BatchMoveReq) error {
	moveErr := u.nodeRepo.BatchMove(ctx, req)
	return moveErr
}
// convertMDToHTML renders the Markdown in mdStr to HTML and passes the result
// through bluemonday's UGC policy before returning it.
//
// The parser runs with the common extension set minus Autolink and MathJax;
// the renderer uses the common flags plus HrefTargetBlank.
func (u *NodeUsecase) convertMDToHTML(mdStr string) string {
	enabledExtensions := parser.CommonExtensions &^ (parser.Autolink | parser.MathJax)
	ast := parser.NewWithExtensions(enabledExtensions).Parse([]byte(mdStr))

	renderer := html.NewRenderer(html.RendererOptions{
		Flags: html.CommonFlags | html.HrefTargetBlank,
	})

	// The rendered output may still carry unsafe markup, so sanitize it.
	rendered := markdown.Render(ast, renderer)
	sanitized := bluemonday.UGCPolicy().SanitizeBytes(rendered)
	return string(sanitized)
}
// GetNodeReleaseListByKBID returns the released nodes of knowledge base kbID
// that the auth identity authId is allowed to see, in repository order.
//
// A node is kept when its Visible permission is open, or when it is partial
// and the node ID is among those granted to the identity's auth groups for the
// visible permission. Nodes with any other Visible value are dropped. The
// result is never nil.
func (u *NodeUsecase) GetNodeReleaseListByKBID(ctx context.Context, kbID string, authId uint) ([]*domain.ShareNodeListItemResp, error) {
	nodes, err := u.nodeRepo.GetNodeReleaseListByKBID(ctx, kbID)
	if err != nil {
		return nil, err
	}
	nodeGroupIds, err := u.GetNodeIdsByAuthId(ctx, authId, consts.NodePermNameVisible)
	if err != nil {
		return nil, err
	}

	// Index the granted IDs once so the per-node check is a constant-time
	// lookup rather than a linear scan of the whole list.
	granted := make(map[string]struct{}, len(nodeGroupIds))
	for _, id := range nodeGroupIds {
		granted[id] = struct{}{}
	}

	items := make([]*domain.ShareNodeListItemResp, 0, len(nodes))
	for _, node := range nodes {
		switch node.Permissions.Visible {
		case consts.NodeAccessPermOpen:
			items = append(items, node)
		case consts.NodeAccessPermPartial:
			if _, ok := granted[node.ID]; ok {
				items = append(items, node)
			}
		}
	}
	return items, nil
}
// GetNodeReleaseListByParentID returns the tree of released nodes below
// parentID in knowledge base kbID, restricted to the nodes the auth identity
// authId is allowed to see.
//
// All released nodes of the knowledge base are loaded in one query and
// filtered by their Visible permission (open, or partial with the node ID
// granted to the identity's auth groups). The survivors are grouped by
// ParentID and assembled into a tree by buildNodeTree, so a node whose parent
// was filtered out is not reachable from that parent.
func (u *NodeUsecase) GetNodeReleaseListByParentID(ctx context.Context, kbID, parentID string, authId uint) ([]*domain.ShareNodeDetailItem, error) {
	// Load every released node of the knowledge base in a single query.
	allNodes, err := u.nodeRepo.GetNodeReleaseListByKBID(ctx, kbID)
	if err != nil {
		return nil, err
	}
	nodeGroupIds, err := u.GetNodeIdsByAuthId(ctx, authId, consts.NodePermNameVisible)
	if err != nil {
		return nil, err
	}

	// Index the granted IDs once so the per-node check is a constant-time
	// lookup rather than a linear scan of the whole list.
	granted := make(map[string]struct{}, len(nodeGroupIds))
	for _, id := range nodeGroupIds {
		granted[id] = struct{}{}
	}

	// Filter by permission first.
	visibleNodes := make([]*domain.ShareNodeListItemResp, 0, len(allNodes))
	for _, node := range allNodes {
		switch node.Permissions.Visible {
		case consts.NodeAccessPermOpen:
			visibleNodes = append(visibleNodes, node)
		case consts.NodeAccessPermPartial:
			if _, ok := granted[node.ID]; ok {
				visibleNodes = append(visibleNodes, node)
			}
		}
	}

	// Build the parent-ID-to-children mapping.
	childrenMap := make(map[string][]*domain.ShareNodeListItemResp)
	for _, node := range visibleNodes {
		childrenMap[node.ParentID] = append(childrenMap[node.ParentID], node)
	}

	// Assemble the tree rooted at parentID.
	return u.buildNodeTree(parentID, childrenMap), nil
}
// buildNodeTree recursively converts the entries of childrenMap registered
// under parentID into tree items. Only nodes of folder type are descended
// into; every item's Children is a non-nil slice, empty when there is nothing
// below it. The result itself is never nil.
func (u *NodeUsecase) buildNodeTree(parentID string, childrenMap map[string][]*domain.ShareNodeListItemResp) []*domain.ShareNodeDetailItem {
	siblings := childrenMap[parentID]
	tree := make([]*domain.ShareNodeDetailItem, 0, len(siblings))

	for _, src := range siblings {
		item := &domain.ShareNodeDetailItem{
			ID:        src.ID,
			Name:      src.Name,
			Type:      src.Type,
			ParentID:  src.ParentID,
			Position:  src.Position,
			Meta:      src.Meta,
			Emoji:     src.Emoji,
			UpdatedAt: src.UpdatedAt,
			Children:  make([]*domain.ShareNodeDetailItem, 0),
		}

		// Folders get their subtree attached; appending an empty subtree
		// leaves Children as the empty slice created above.
		if src.Type == domain.NodeTypeFolder {
			subtree := u.buildNodeTree(src.ID, childrenMap)
			item.Children = append(item.Children, subtree...)
		}

		tree = append(tree, item)
	}
	return tree
}
// GetNodeIdsByAuthId returns the IDs of the nodes bound, for permission
// PermName, to any auth group of authId (the group lookup includes parent
// groups). When the identity has no auth groups the node-group query is
// skipped and an empty, non-nil slice is returned.
func (u *NodeUsecase) GetNodeIdsByAuthId(ctx context.Context, authId uint, PermName consts.NodePermName) ([]string, error) {
	groups, err := u.authRepo.GetAuthGroupWithParentsByAuthId(ctx, authId)
	if err != nil {
		return nil, err
	}
	if len(groups) == 0 {
		return make([]string, 0), nil
	}

	groupIDs := make([]uint, 0, len(groups))
	for _, group := range groups {
		groupIDs = append(groupIDs, group.ID)
	}

	bindings, err := u.nodeRepo.GetNodeGroupsByGroupIdsPerm(ctx, groupIDs, PermName)
	if err != nil {
		return nil, err
	}

	nodeIDs := make([]string, 0, len(bindings))
	for _, binding := range bindings {
		nodeIDs = append(nodeIDs, binding.NodeID)
	}
	return nodeIDs, nil
}
// GetNodePermissionsByID returns the permission settings of node id in
// knowledge base kbID together with the node's group bindings, split into
// answerable, visitable and visible lists according to each binding's Perm.
// Bindings with any other Perm value are ignored. All three lists are non-nil
// even when empty.
func (u *NodeUsecase) GetNodePermissionsByID(ctx context.Context, id, kbID string) (*v1.NodePermissionResp, error) {
	node, err := u.nodeRepo.GetByID(ctx, id, kbID)
	if err != nil {
		return nil, err
	}

	bindings, err := u.nodeRepo.GetNodeGroupByNodeId(ctx, node.ID)
	if err != nil {
		return nil, err
	}

	answerable := make([]domain.NodeGroupDetail, 0)
	visitable := make([]domain.NodeGroupDetail, 0)
	visible := make([]domain.NodeGroupDetail, 0)
	for _, binding := range bindings {
		switch binding.Perm {
		case consts.NodePermNameAnswerable:
			answerable = append(answerable, binding)
		case consts.NodePermNameVisitable:
			visitable = append(visitable, binding)
		case consts.NodePermNameVisible:
			visible = append(visible, binding)
		}
	}

	return &v1.NodePermissionResp{
		ID:               node.ID,
		Permissions:      node.Permissions,
		AnswerableGroups: answerable,
		VisitableGroups:  visitable,
		VisibleGroups:    visible,
	}, nil
}
// ValidateNodePermissionsEdit checks whether the permission edit req is
// allowed under the given license edition.
//
// Business and enterprise editions may submit any edit. Every other edition is
// refused with domain.ErrPermissionDenied when the request sets any permission
// to partial or carries any group list (answerable, visitable or visible).
//
// req.Permissions is optional: NodePermissionsEdit accepts a nil value, so it
// is only inspected here when present instead of being dereferenced blindly.
func (u *NodeUsecase) ValidateNodePermissionsEdit(req v1.NodePermissionEditReq, edition consts.LicenseEdition) error {
	if slices.Contains([]consts.LicenseEdition{consts.LicenseEditionBusiness, consts.LicenseEditionEnterprise}, edition) {
		return nil
	}

	if perms := req.Permissions; perms != nil {
		if perms.Answerable == consts.NodeAccessPermPartial ||
			perms.Visitable == consts.NodeAccessPermPartial ||
			perms.Visible == consts.NodeAccessPermPartial {
			return domain.ErrPermissionDenied
		}
	}
	if req.AnswerableGroups != nil || req.VisitableGroups != nil || req.VisibleGroups != nil {
		return domain.ErrPermissionDenied
	}
	return nil
}
// NodePermissionsEdit applies the permission edit req to the nodes req.IDs of
// knowledge base req.KbId. In order, it:
//
//  1. passes req.Permissions, when set, to UpdateNodesByKbID under the
//     "permissions" key;
//  2. loads the latest releases of the nodes and, when req.Permissions is set,
//     submits one "update_group_ids" vector request per release that has a
//     DocID. The group IDs sent are nil for an open answerable permission, the
//     request's AnswerableGroups for partial, and an empty list for closed;
//  3. passes each non-nil group list (answerable, visible, visitable) to
//     UpdateNodeGroupByKbIDAndNodeIds.
//
// Both req.Permissions and the group lists are optional pointers. Step 2 is
// skipped when req.Permissions is nil, or when the answerable permission is
// partial but no AnswerableGroups were supplied, because in either case there
// is no group list to send and dereferencing the nil pointer would panic.
func (u *NodeUsecase) NodePermissionsEdit(ctx context.Context, req v1.NodePermissionEditReq) error {
	if req.Permissions != nil {
		updateMap := map[string]interface{}{
			"permissions": req.Permissions,
		}
		if err := u.nodeRepo.UpdateNodesByKbID(ctx, req.IDs, req.KbId, updateMap); err != nil {
			return err
		}
	}

	nodeReleases, err := u.nodeRepo.GetLatestNodeReleaseByNodeIDs(ctx, req.KbId, req.IDs)
	if err != nil {
		return fmt.Errorf("get latest node release failed: %w", err)
	}

	if req.Permissions != nil && len(nodeReleases) > 0 {
		var groupIds []int // stays nil for open and for unrecognised values
		syncGroupIds := true
		switch req.Permissions.Answerable {
		case consts.NodeAccessPermPartial:
			if req.AnswerableGroups != nil {
				groupIds = *req.AnswerableGroups
			} else {
				// NOTE(review): no group list was supplied, so the vector
				// store is left as it is, mirroring the nil check that guards
				// the group update below. Confirm whether this case should
				// fail closed (send an empty list) instead.
				syncGroupIds = false
				u.logger.Warn("answerable permission is partial but no answerable groups were provided, skip group ids sync",
					log.Any("kb_id", req.KbId),
					log.Any("node_ids", req.IDs))
			}
		case consts.NodeAccessPermClosed:
			groupIds = make([]int, 0)
		}

		if syncGroupIds {
			nodeVectorContentRequests := make([]*domain.NodeReleaseVectorRequest, 0, len(nodeReleases))
			for _, nodeRelease := range nodeReleases {
				// Releases without a document ID are not sent.
				if nodeRelease.DocID == "" {
					continue
				}
				nodeVectorContentRequests = append(nodeVectorContentRequests, &domain.NodeReleaseVectorRequest{
					KBID:     req.KbId,
					DocID:    nodeRelease.DocID,
					Action:   "update_group_ids",
					GroupIds: groupIds,
				})
			}
			if len(nodeVectorContentRequests) != 0 {
				if err := u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, nodeVectorContentRequests); err != nil {
					return err
				}
			}
		}
	}

	if req.AnswerableGroups != nil {
		if err := u.nodeRepo.UpdateNodeGroupByKbIDAndNodeIds(ctx, req.IDs, *req.AnswerableGroups, consts.NodePermNameAnswerable); err != nil {
			return err
		}
	}
	if req.VisibleGroups != nil {
		if err := u.nodeRepo.UpdateNodeGroupByKbIDAndNodeIds(ctx, req.IDs, *req.VisibleGroups, consts.NodePermNameVisible); err != nil {
			return err
		}
	}
	if req.VisitableGroups != nil {
		if err := u.nodeRepo.UpdateNodeGroupByKbIDAndNodeIds(ctx, req.IDs, *req.VisitableGroups, consts.NodePermNameVisitable); err != nil {
			return err
		}
	}
	return nil
}
// SyncRagNodeStatus copies document status from the RAG service onto nodes.
//
// For every knowledge base it asks the node repository for the IDs returned by
// GetNodeIdsWithoutStatusByKbId (used below as RAG document IDs), queries the
// RAG service for them in chunks of ragSyncChunkSize, maps the returned
// documents back to node IDs, and writes a "rag_info" update per distinct
// (status, progress message) pair.
//
// Only a failure to list the knowledge bases is returned. Every later failure
// is logged and the affected knowledge base, chunk or update group is skipped.
func (u *NodeUsecase) SyncRagNodeStatus(ctx context.Context) error {
	// ragState is the grouping key: nodes sharing it are updated in one call.
	type ragState struct {
		status  string
		message string
	}

	knowledgeBases, err := u.kbRepo.GetKnowledgeBaseList(ctx)
	if err != nil {
		return err
	}

	for _, kb := range knowledgeBases {
		pendingDocIDs, err := u.nodeRepo.GetNodeIdsWithoutStatusByKbId(ctx, kb.ID)
		if err != nil {
			u.logger.Error("get node ids without status failed",
				log.String("kb_id", kb.ID),
				log.Error(err))
			continue
		}
		if len(pendingDocIDs) == 0 {
			continue
		}

		for _, batch := range lo.Chunk(pendingDocIDs, ragSyncChunkSize) {
			filter := map[string]string{
				"ids": strings.Join(batch, ","),
			}
			ragDocs, err := u.rAGService.ListDocuments(ctx, kb.DatasetID, filter)
			if err != nil {
				u.logger.Error("list documents from RAG failed",
					log.String("kb_id", kb.ID),
					log.String("dataset_id", kb.DatasetID),
					log.Error(err))
				continue
			}
			if len(ragDocs) == 0 {
				continue
			}

			nodeIDByDocID, err := u.nodeRepo.GetNodeIdsByDocIds(ctx, batch)
			if err != nil {
				u.logger.Error("get node ids by doc ids failed",
					log.String("kb_id", kb.ID),
					log.Error(err))
				continue
			}

			// Bucket node IDs by the state the RAG service reports.
			nodeIDsByState := make(map[ragState][]string)
			for _, ragDoc := range ragDocs {
				nodeID, known := nodeIDByDocID[ragDoc.ID]
				if !known {
					u.logger.Warn("doc_id not found in node_releases",
						log.String("doc_id", ragDoc.ID))
					continue
				}
				state := ragState{status: ragDoc.Status, message: ragDoc.ProgressMsg}
				nodeIDsByState[state] = append(nodeIDsByState[state], nodeID)
			}

			for state, nodeIDs := range nodeIDsByState {
				ragInfo := domain.RagInfo{
					Status:  consts.NodeRagInfoStatus(state.status),
					Message: state.message,
				}
				updateErr := u.nodeRepo.UpdateNodesByKbID(ctx, nodeIDs, kb.ID, map[string]interface{}{
					"rag_info": ragInfo,
				})
				if updateErr != nil {
					u.logger.Error("batch update node rag status failed",
						log.String("kb_id", kb.ID),
						log.Int("node_count", len(nodeIDs)),
						log.String("status", state.status),
						log.Error(updateErr))
					continue
				}
				u.logger.Debug("batch updated node rag status",
					log.String("kb_id", kb.ID),
					log.Int("node_count", len(nodeIDs)),
					log.String("status", state.status))
			}
		}
	}
	return nil
}
// NodeRestudy submits an "upsert" vector request for the latest release of
// each node in req.NodeIds. Releases without a DocID are skipped. Requests are
// submitted one at a time; a submission failure is logged and does not stop
// the remaining releases, so the only error returned is a failure to load the
// releases.
func (u *NodeUsecase) NodeRestudy(ctx context.Context, req *v1.NodeRestudyReq) error {
	releases, err := u.nodeRepo.GetLatestNodeReleaseByNodeIDs(ctx, req.KbId, req.NodeIds)
	if err != nil {
		return fmt.Errorf("get latest node release failed: %w", err)
	}

	for _, release := range releases {
		if release.DocID == "" {
			continue
		}

		upsert := &domain.NodeReleaseVectorRequest{
			KBID:          release.KBID,
			NodeReleaseID: release.ID,
			Action:        "upsert",
		}
		submitErr := u.ragRepo.AsyncUpdateNodeReleaseVector(ctx, []*domain.NodeReleaseVectorRequest{upsert})
		if submitErr != nil {
			u.logger.Error("async update node release vector failed",
				log.String("node_release_id", release.ID),
				log.Error(submitErr))
		}
	}
	return nil
}