Compare commits

..

2 Commits

Author SHA1 Message Date
jiangwei 7b8c148f3c
Merge 6bb8cb08ca into d3502e105a 2025-11-22 03:13:36 +00:00
jiangwel 6bb8cb08ca feat(chat): 添加 ChatRagOnlyRequset 结构并简化 ChatRagOnly 逻辑
重构 ChatRagOnly 方法,使用新的 ChatRagOnlyRequset 结构作为参数
移除对话管理相关代码,仅保留文档检索和敏感词检查功能
新增 MCPCall 结构用于记录调用信息
2025-11-22 11:12:59 +08:00
3 changed files with 40 additions and 107 deletions

View File

@ -23,6 +23,16 @@ type ChatRequest struct {
Info ConversationInfo `json:"-"`
}
type ChatRagOnlyRequset struct {
Message string `json:"message" validate:"required"`
CaptchaToken string `json:"captcha_token"`
KBID string `json:"-" validate:"required"`
UserInfo UserInfo `json:"user_info"`
AppType AppType `json:"app_type" validate:"required,oneof=1 2"`
}
type ConversationInfo struct {
UserInfo UserInfo `json:"user_info"`
}

15
backend/domain/mcp.go Normal file
View File

@ -0,0 +1,15 @@
package domain
import (
"time"
)
type MCPCall struct {
ID string `gorm:"primaryKey;column:id" json:"id,omitempty"`
ClientName string `gorm:"column:client_name" json:"client_name"`
ClientVersion string `gorm:"column:client_version" json:"client_version"`
Question string `gorm:"column:question" json:"question"`
Document string `gorm:"column:document" json:"document"`
RemoteIP string `gorm:"column:remote_ip" json:"remoate_ip"`
CreatedAt time.Time `gorm:"column:created_at;not null;default:now()" json:"created_at"`
}

View File

@ -301,77 +301,11 @@ func (u *ChatUsecase) Chat(ctx context.Context, req *domain.ChatRequest) (<-chan
return eventCh, nil
}
func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRequest) (<-chan domain.SSEEvent, error) {
func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRagOnlyRequset) (<-chan domain.SSEEvent, error) {
eventCh := make(chan domain.SSEEvent, 100)
go func() {
defer close(eventCh)
// get app detail and validate app
app, err := u.appRepo.GetOrCreateAppByKBIDAndType(ctx, req.KBID, req.AppType)
if err != nil {
eventCh <- domain.SSEEvent{Type: "error", Content: "app not found"}
return
}
req.KBID = app.KBID
req.AppID = app.ID
req.AppType = app.Type
// conversation management
if req.ConversationID == "" {
id, err := uuid.NewV7()
if err != nil {
u.logger.Error("failed to generate conversation uuid", log.Error(err))
id = uuid.New()
}
conversationID := id.String()
req.ConversationID = conversationID
nonce := uuid.New().String()
eventCh <- domain.SSEEvent{Type: "conversation_id", Content: conversationID}
eventCh <- domain.SSEEvent{Type: "nonce", Content: nonce}
err = u.conversationUsecase.CreateConversation(ctx, &domain.Conversation{
ID: conversationID,
Nonce: nonce,
AppID: req.AppID,
KBID: req.KBID,
Subject: req.Message,
RemoteIP: req.RemoteIP,
Info: req.Info,
CreatedAt: time.Now(),
})
if err != nil {
u.logger.Error("failed to create chat conversation", log.Error(err))
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to create chat conversation"}
return
}
} else {
if req.Nonce == "" {
eventCh <- domain.SSEEvent{Type: "error", Content: "nonce is required"}
return
}
err := u.conversationUsecase.ValidateConversationNonce(ctx, req.ConversationID, req.Nonce)
if err != nil {
u.logger.Error("failed to validate chat conversation nonce", log.Error(err))
eventCh <- domain.SSEEvent{Type: "error", Content: "validate chat conversation nonce failed"}
return
}
}
messageId := uuid.New().String()
eventCh <- domain.SSEEvent{Type: "message_id", Content: messageId}
userMessageId := uuid.New().String()
// save user question to conversation message
if err := u.conversationUsecase.CreateChatConversationMessage(ctx, req.KBID, &domain.ConversationMessage{
ID: userMessageId,
ConversationID: req.ConversationID,
KBID: req.KBID,
AppID: req.AppID,
Role: schema.User,
Content: req.Message,
RemoteIP: req.RemoteIP,
}); err != nil {
u.logger.Error("failed to save user question to conversation message", log.Error(err))
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to save user question to conversation message"}
return
}
// extra1. if user set question block words then check it
blockWords, err := u.blockWordRepo.GetBlockWords(ctx, req.KBID)
if err != nil {
@ -384,35 +318,18 @@ func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRequest)
if err := questionFilter.DFA.Check(req.Message); err != nil { // exist then return err
answer := "**您的问题包含敏感词, AI 无法回答您的问题。**"
eventCh <- domain.SSEEvent{Type: "error", Content: answer}
// save ai answer and set it err
if err := u.conversationUsecase.CreateChatConversationMessage(context.Background(), req.KBID, &domain.ConversationMessage{
ID: messageId,
ConversationID: req.ConversationID,
KBID: req.KBID,
AppID: req.AppID,
Role: schema.Assistant,
Content: answer,
Provider: req.ModelInfo.Provider,
Model: string(req.ModelInfo.Model),
RemoteIP: req.RemoteIP,
ParentID: userMessageId,
}); err != nil {
u.logger.Error("failed to save assistant answer to conversation message", log.Error(err))
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to save assistant answer to conversation message"}
return
}
return
}
}
if req.Info.UserInfo.AuthUserID == 0 {
if req.UserInfo.AuthUserID == 0 {
auth, _ := u.AuthRepo.GetAuthBySourceType(ctx, req.AppType.ToSourceType())
if auth != nil {
req.Info.UserInfo.AuthUserID = auth.ID
req.UserInfo.AuthUserID = auth.ID
}
}
groupIds, err := u.AuthRepo.GetAuthGroupIdsWithParentsByAuthId(ctx, req.Info.UserInfo.AuthUserID)
groupIds, err := u.AuthRepo.GetAuthGroupIdsWithParentsByAuthId(ctx, req.UserInfo.AuthUserID)
if err != nil {
u.logger.Error("failed to get auth groupIds", log.Error(err))
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get auth groupIds"}
@ -420,32 +337,23 @@ func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRequest)
}
// retrieve documents
messages, _, err := u.llmUsecase.FormatConversationMessages(ctx, req.ConversationID, req.KBID, groupIds)
kb, err := u.kbRepo.GetKnowledgeBaseByID(ctx, req.KBID)
if err != nil {
u.logger.Error("failed to format chat messages", log.Error(err))
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to format chat messages"}
u.logger.Error("failed to get kb", log.Error(err))
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get kb"}
return
}
u.logger.Info("message:", log.Any("schema", messages))
// extract <documents>...</documents> block from the formatted USER message
var documentsContent string
// find the last occurrence in any message (avoid picking system prompt placeholders)
for i := len(messages) - 1; i >= 0; i-- {
content := messages[i].Content
if strings.Contains(content, "<documents>") {
start := strings.Index(content, "<documents>")
end := strings.Index(content, "</documents>")
if start != -1 && end != -1 && end > start {
documentsContent = strings.TrimSpace(content[start+len("<documents>") : end])
break
}
}
rankedNodes, err := u.llmUsecase.GetRankNodes(ctx, []string{kb.DatasetID}, req.Message, groupIds, 0, nil)
if err != nil {
u.logger.Error("failed to get rank nodes", log.Error(err))
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get rank nodes"}
return
}
documents := domain.FormatNodeChunks(rankedNodes, kb.AccessSettings.BaseURL)
u.logger.Debug("documents", log.String("documents", documents))
// send only the documents part
eventCh <- domain.SSEEvent{Type: "data", Content: documentsContent}
eventCh <- domain.SSEEvent{Type: "data", Content: documents}
eventCh <- domain.SSEEvent{Type: "done"}
}()
return eventCh, nil