|
|
|
@ -301,11 +301,77 @@ func (u *ChatUsecase) Chat(ctx context.Context, req *domain.ChatRequest) (<-chan
|
|
|
|
return eventCh, nil
|
|
|
|
return eventCh, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRagOnlyRequset) (<-chan domain.SSEEvent, error) {
|
|
|
|
func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRequest) (<-chan domain.SSEEvent, error) {
|
|
|
|
eventCh := make(chan domain.SSEEvent, 100)
|
|
|
|
eventCh := make(chan domain.SSEEvent, 100)
|
|
|
|
go func() {
|
|
|
|
go func() {
|
|
|
|
defer close(eventCh)
|
|
|
|
defer close(eventCh)
|
|
|
|
|
|
|
|
// get app detail and validate app
|
|
|
|
|
|
|
|
app, err := u.appRepo.GetOrCreateAppByKBIDAndType(ctx, req.KBID, req.AppType)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "app not found"}
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
req.KBID = app.KBID
|
|
|
|
|
|
|
|
req.AppID = app.ID
|
|
|
|
|
|
|
|
req.AppType = app.Type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// conversation management
|
|
|
|
|
|
|
|
if req.ConversationID == "" {
|
|
|
|
|
|
|
|
id, err := uuid.NewV7()
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
u.logger.Error("failed to generate conversation uuid", log.Error(err))
|
|
|
|
|
|
|
|
id = uuid.New()
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
conversationID := id.String()
|
|
|
|
|
|
|
|
req.ConversationID = conversationID
|
|
|
|
|
|
|
|
nonce := uuid.New().String()
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "conversation_id", Content: conversationID}
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "nonce", Content: nonce}
|
|
|
|
|
|
|
|
err = u.conversationUsecase.CreateConversation(ctx, &domain.Conversation{
|
|
|
|
|
|
|
|
ID: conversationID,
|
|
|
|
|
|
|
|
Nonce: nonce,
|
|
|
|
|
|
|
|
AppID: req.AppID,
|
|
|
|
|
|
|
|
KBID: req.KBID,
|
|
|
|
|
|
|
|
Subject: req.Message,
|
|
|
|
|
|
|
|
RemoteIP: req.RemoteIP,
|
|
|
|
|
|
|
|
Info: req.Info,
|
|
|
|
|
|
|
|
CreatedAt: time.Now(),
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
u.logger.Error("failed to create chat conversation", log.Error(err))
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to create chat conversation"}
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
if req.Nonce == "" {
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "nonce is required"}
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
err := u.conversationUsecase.ValidateConversationNonce(ctx, req.ConversationID, req.Nonce)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
u.logger.Error("failed to validate chat conversation nonce", log.Error(err))
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "validate chat conversation nonce failed"}
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
messageId := uuid.New().String()
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "message_id", Content: messageId}
|
|
|
|
|
|
|
|
userMessageId := uuid.New().String()
|
|
|
|
|
|
|
|
// save user question to conversation message
|
|
|
|
|
|
|
|
if err := u.conversationUsecase.CreateChatConversationMessage(ctx, req.KBID, &domain.ConversationMessage{
|
|
|
|
|
|
|
|
ID: userMessageId,
|
|
|
|
|
|
|
|
ConversationID: req.ConversationID,
|
|
|
|
|
|
|
|
KBID: req.KBID,
|
|
|
|
|
|
|
|
AppID: req.AppID,
|
|
|
|
|
|
|
|
Role: schema.User,
|
|
|
|
|
|
|
|
Content: req.Message,
|
|
|
|
|
|
|
|
RemoteIP: req.RemoteIP,
|
|
|
|
|
|
|
|
}); err != nil {
|
|
|
|
|
|
|
|
u.logger.Error("failed to save user question to conversation message", log.Error(err))
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to save user question to conversation message"}
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|
|
|
|
// extra1. if user set question block words then check it
|
|
|
|
// extra1. if user set question block words then check it
|
|
|
|
blockWords, err := u.blockWordRepo.GetBlockWords(ctx, req.KBID)
|
|
|
|
blockWords, err := u.blockWordRepo.GetBlockWords(ctx, req.KBID)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
@ -318,18 +384,35 @@ func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRagOnlyRe
|
|
|
|
if err := questionFilter.DFA.Check(req.Message); err != nil { // exist then return err
|
|
|
|
if err := questionFilter.DFA.Check(req.Message); err != nil { // exist then return err
|
|
|
|
answer := "**您的问题包含敏感词, AI 无法回答您的问题。**"
|
|
|
|
answer := "**您的问题包含敏感词, AI 无法回答您的问题。**"
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: answer}
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: answer}
|
|
|
|
|
|
|
|
// save ai answer and set it err
|
|
|
|
|
|
|
|
if err := u.conversationUsecase.CreateChatConversationMessage(context.Background(), req.KBID, &domain.ConversationMessage{
|
|
|
|
|
|
|
|
ID: messageId,
|
|
|
|
|
|
|
|
ConversationID: req.ConversationID,
|
|
|
|
|
|
|
|
KBID: req.KBID,
|
|
|
|
|
|
|
|
AppID: req.AppID,
|
|
|
|
|
|
|
|
Role: schema.Assistant,
|
|
|
|
|
|
|
|
Content: answer,
|
|
|
|
|
|
|
|
Provider: req.ModelInfo.Provider,
|
|
|
|
|
|
|
|
Model: string(req.ModelInfo.Model),
|
|
|
|
|
|
|
|
RemoteIP: req.RemoteIP,
|
|
|
|
|
|
|
|
ParentID: userMessageId,
|
|
|
|
|
|
|
|
}); err != nil {
|
|
|
|
|
|
|
|
u.logger.Error("failed to save assistant answer to conversation message", log.Error(err))
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to save assistant answer to conversation message"}
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
}
|
|
|
|
return
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if req.UserInfo.AuthUserID == 0 {
|
|
|
|
if req.Info.UserInfo.AuthUserID == 0 {
|
|
|
|
auth, _ := u.AuthRepo.GetAuthBySourceType(ctx, req.AppType.ToSourceType())
|
|
|
|
auth, _ := u.AuthRepo.GetAuthBySourceType(ctx, req.AppType.ToSourceType())
|
|
|
|
if auth != nil {
|
|
|
|
if auth != nil {
|
|
|
|
req.UserInfo.AuthUserID = auth.ID
|
|
|
|
req.Info.UserInfo.AuthUserID = auth.ID
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
groupIds, err := u.AuthRepo.GetAuthGroupIdsWithParentsByAuthId(ctx, req.UserInfo.AuthUserID)
|
|
|
|
groupIds, err := u.AuthRepo.GetAuthGroupIdsWithParentsByAuthId(ctx, req.Info.UserInfo.AuthUserID)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
u.logger.Error("failed to get auth groupIds", log.Error(err))
|
|
|
|
u.logger.Error("failed to get auth groupIds", log.Error(err))
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get auth groupIds"}
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get auth groupIds"}
|
|
|
|
@ -337,23 +420,32 @@ func (u *ChatUsecase) ChatRagOnly(ctx context.Context, req *domain.ChatRagOnlyRe
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// retrieve documents
|
|
|
|
// retrieve documents
|
|
|
|
kb, err := u.kbRepo.GetKnowledgeBaseByID(ctx, req.KBID)
|
|
|
|
messages, _, err := u.llmUsecase.FormatConversationMessages(ctx, req.ConversationID, req.KBID, groupIds)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
u.logger.Error("failed to get kb", log.Error(err))
|
|
|
|
u.logger.Error("failed to format chat messages", log.Error(err))
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get kb"}
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to format chat messages"}
|
|
|
|
return
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
rankedNodes, err := u.llmUsecase.GetRankNodes(ctx, []string{kb.DatasetID}, req.Message, groupIds, 0, nil)
|
|
|
|
u.logger.Info("message:", log.Any("schema", messages))
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
u.logger.Error("failed to get rank nodes", log.Error(err))
|
|
|
|
// extract <documents>...</documents> block from the formatted USER message
|
|
|
|
eventCh <- domain.SSEEvent{Type: "error", Content: "failed to get rank nodes"}
|
|
|
|
var documentsContent string
|
|
|
|
return
|
|
|
|
// find the last occurrence in any message (avoid picking system prompt placeholders)
|
|
|
|
|
|
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
|
|
|
|
|
|
content := messages[i].Content
|
|
|
|
|
|
|
|
if strings.Contains(content, "<documents>") {
|
|
|
|
|
|
|
|
start := strings.Index(content, "<documents>")
|
|
|
|
|
|
|
|
end := strings.Index(content, "</documents>")
|
|
|
|
|
|
|
|
if start != -1 && end != -1 && end > start {
|
|
|
|
|
|
|
|
documentsContent = strings.TrimSpace(content[start+len("<documents>") : end])
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
documents := domain.FormatNodeChunks(rankedNodes, kb.AccessSettings.BaseURL)
|
|
|
|
|
|
|
|
u.logger.Debug("documents", log.String("documents", documents))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// send only the documents part
|
|
|
|
// send only the documents part
|
|
|
|
eventCh <- domain.SSEEvent{Type: "data", Content: documents}
|
|
|
|
eventCh <- domain.SSEEvent{Type: "data", Content: documentsContent}
|
|
|
|
|
|
|
|
|
|
|
|
eventCh <- domain.SSEEvent{Type: "done"}
|
|
|
|
eventCh <- domain.SSEEvent{Type: "done"}
|
|
|
|
}()
|
|
|
|
}()
|
|
|
|
return eventCh, nil
|
|
|
|
return eventCh, nil
|
|
|
|
|