PandaWiki/backend/usecase/conversation.go

290 lines
9.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package usecase
import (
"context"
"fmt"
"regexp"
"github.com/samber/lo"
"github.com/chaitin/panda-wiki/domain"
"github.com/chaitin/panda-wiki/log"
"github.com/chaitin/panda-wiki/repo/cache"
"github.com/chaitin/panda-wiki/repo/ipdb"
"github.com/chaitin/panda-wiki/repo/pg"
)
// ConversationUsecase implements the conversation use cases: storing
// conversations and chat messages, listing and reading them back, and
// recording user feedback on messages.
type ConversationUsecase struct {
	// repo persists and queries conversations, messages, references and feedback.
	repo *pg.ConversationRepository
	// nodeRepo is injected by the constructor but is not referenced by any
	// method visible in this file. NOTE(review): confirm it is still needed.
	nodeRepo *pg.NodeRepository
	// geoCacheRepo receives a "country|province|city" location string, keyed by
	// knowledge base ID, whenever a conversation is created.
	geoCacheRepo *cache.GeoRepo
	// logger is scoped to the "usecase.conversation" module by the constructor.
	logger *log.Logger
	// ipRepo resolves remote IP strings to IPAddress (geolocation) records.
	ipRepo *ipdb.IPAddressRepo
	// authRepo resolves auth user IDs to user profiles (username, avatar, email).
	authRepo *pg.AuthRepo
}
// NewConversationUsecase builds a ConversationUsecase from its repositories.
// The supplied logger is narrowed to the "usecase.conversation" module before
// being stored.
func NewConversationUsecase(
	repo *pg.ConversationRepository,
	nodeRepo *pg.NodeRepository,
	geoCacheRepo *cache.GeoRepo,
	logger *log.Logger,
	ipRepo *ipdb.IPAddressRepo,
	authRepo *pg.AuthRepo,
) *ConversationUsecase {
	uc := new(ConversationUsecase)
	uc.repo = repo
	uc.nodeRepo = nodeRepo
	uc.geoCacheRepo = geoCacheRepo
	uc.ipRepo = ipRepo
	uc.authRepo = authRepo
	uc.logger = logger.WithModule("usecase.conversation")
	return uc
}
// CreateChatConversationMessage persists a chat message together with the
// references parsed from the trailing reference block of its content.
// kbID is accepted for API symmetry but is not used.
func (u *ConversationUsecase) CreateChatConversationMessage(ctx context.Context, kbID string, conversation *domain.ConversationMessage) error {
	return u.repo.CreateConversationMessage(
		ctx,
		conversation,
		extractReferencesBlock(conversation.ID, conversation.AppID, conversation.Content),
	)
}
// GetConversationList returns one page of conversations, each enriched with
// its latest feedback info, the resolved auth user profile and the
// geolocation of its remote IP.
//
// Enrichment is best-effort: failures of the feedback, user-info or IP
// lookups are logged and the affected fields keep the values returned by the
// repository. Only an error from the list query itself is returned.
func (u *ConversationUsecase) GetConversationList(ctx context.Context, request *domain.ConversationListReq) (*domain.PaginatedResult[[]*domain.ConversationListItem], error) {
	conversations, total, err := u.repo.GetConversationList(ctx, request)
	if err != nil {
		return nil, err
	}

	conversationIDs := make([]string, 0, len(conversations))
	authIDs := make([]uint, 0, len(conversations))
	for _, c := range conversations {
		conversationIDs = append(conversationIDs, c.ID)
		// Skip conversations without an auth user (ID 0) so the user-info
		// query only receives real IDs.
		if c.Info.UserInfo.AuthUserID != 0 {
			authIDs = append(authIDs, c.Info.UserInfo.AuthUserID)
		}
	}

	// Latest user feedback per conversation.
	feedbackMap, err := u.repo.GetConversationFeedBackInfoByIDs(ctx, conversationIDs)
	if err != nil {
		u.logger.Error("get latest feedback by conversation id failed", log.Error(err))
	}
	// Several conversations usually share one user; query each ID once.
	authMap, err := u.authRepo.GetAuthUserinfoByIDs(ctx, lo.Uniq(authIDs))
	if err != nil {
		u.logger.Error("get user info failed", log.Error(err))
	}

	// Per-request IP cache. Failed lookups are cached as nil so a bad IP is
	// queried (and logged) only once.
	ipCache := make(map[string]*domain.IPAddress)
	lookupIP := func(ip string) *domain.IPAddress {
		if addr, ok := ipCache[ip]; ok {
			return addr
		}
		addr, lookupErr := u.ipRepo.GetIPAddress(ctx, ip)
		if lookupErr != nil {
			u.logger.Error("get ip address failed", log.Error(lookupErr), log.String("ip", ip))
			addr = nil
		}
		ipCache[ip] = addr
		return addr
	}

	for _, conversation := range conversations {
		if addr := lookupIP(conversation.RemoteIP); addr != nil {
			conversation.IPAddress = addr
		}
		// Feedback and user info must be applied even when the IP lookup
		// failed; they do not depend on it.
		if feedback, ok := feedbackMap[conversation.ID]; ok {
			conversation.FeedBackInfo = feedback
		}
		if auth, ok := authMap[conversation.Info.UserInfo.AuthUserID]; ok {
			conversation.Info.UserInfo = domain.UserInfo{
				NickName: auth.AuthUserInfo.Username,
				Avatar:   auth.AuthUserInfo.AvatarUrl,
				Email:    auth.AuthUserInfo.Email,
			}
		}
	}
	return domain.NewPaginatedResult(conversations, total), nil
}
// GetConversationDetail loads a conversation of knowledge base kbID together
// with its messages and references.
//
// The remote-IP geolocation is best-effort: a lookup failure is logged and
// IPAddress is left untouched. Any repository error while loading the
// conversation, its messages or its references is returned with a nil result.
func (u *ConversationUsecase) GetConversationDetail(ctx context.Context, kbID, conversationID string) (*domain.ConversationDetailResp, error) {
	detail, err := u.repo.GetConversationDetail(ctx, kbID, conversationID)
	if err != nil {
		return nil, err
	}

	if addr, ipErr := u.ipRepo.GetIPAddress(ctx, detail.RemoteIP); ipErr == nil {
		detail.IPAddress = addr
	} else {
		u.logger.Error("get ip address failed", log.Error(ipErr), log.String("ip", detail.RemoteIP))
	}

	msgs, err := u.repo.GetConversationMessagesByID(ctx, conversationID)
	if err != nil {
		return nil, err
	}
	refs, err := u.repo.GetConversationReferences(ctx, conversationID)
	if err != nil {
		return nil, err
	}

	detail.Messages = msgs
	detail.References = refs
	return detail, nil
}
// Patterns used by extractReferencesBlock. They are compiled once at package
// initialisation rather than on every saved message.
//
// A reference line looks like `> [1]. [title](url)`. The alternative
// `\\u003e` matches a literal backslash followed by "u003e", i.e. a
// JSON-escaped `>` left in the text.
var (
	// conversationReferenceBlockRe matches a run of one or more reference
	// lines; with (?m) the trailing `$` may match at any line end.
	// NOTE(review): (?s) lets the lazy `.*?` span newlines — confirm intended.
	conversationReferenceBlockRe = regexp.MustCompile(`(?ms)((?:>|\\u003e)\s*\[\d+\]\.\s*\[.*?\]\(.*?\)\s*\n?)+$`)
	// conversationReferenceLineRe captures (index, name, url) of one
	// reference line anchored at a line start.
	conversationReferenceLineRe = regexp.MustCompile(`(?m)^(?:>|\\u003e)\s*\[(\d+)\]\.\s*\[(.*?)\]\((.*?)\)`)
)

// extractReferencesBlock parses the final block of markdown reference lines
// in text and returns one ConversationReference per line, stamped with
// conversationID and appID.
//
// It returns nil when text contains no reference block. Only the last block
// counts: every reference line from the start of the last block match to the
// end of text is extracted; earlier blocks are ignored.
func extractReferencesBlock(conversationID, appID, text string) []*domain.ConversationReference {
	blocks := conversationReferenceBlockRe.FindAllStringIndex(text, -1)
	if len(blocks) == 0 {
		return nil
	}
	start := blocks[len(blocks)-1][0]

	lines := conversationReferenceLineRe.FindAllStringSubmatch(text[start:], -1)
	refs := make([]*domain.ConversationReference, 0, len(lines))
	for _, m := range lines {
		// m[1] is the numeric reference index; it is not stored.
		refs = append(refs, &domain.ConversationReference{
			Name:           m[2],
			URL:            m[3],
			ConversationID: conversationID,
			AppID:          appID,
		})
	}
	return refs
}
// ValidateConversationNonce delegates validation of nonce for the given
// conversation to the repository and returns its error unchanged.
func (u *ConversationUsecase) ValidateConversationNonce(ctx context.Context, conversationID, nonce string) error {
	if err := u.repo.ValidateConversationNonce(ctx, conversationID, nonce); err != nil {
		return err
	}
	return nil
}
// CreateConversation stores a new conversation and then, best-effort, records
// the geolocation of its remote IP in the geo cache for its knowledge base.
//
// Only the insert error is returned. IP lookup and cache-write failures are
// logged as warnings and otherwise ignored.
func (u *ConversationUsecase) CreateConversation(ctx context.Context, conversation *domain.Conversation) error {
	if err := u.repo.CreateConversation(ctx, conversation); err != nil {
		return err
	}

	ip := conversation.RemoteIP
	addr, err := u.ipRepo.GetIPAddress(ctx, ip)
	if err != nil {
		u.logger.Warn("get ip address failed", log.Error(err), log.String("ip", ip), log.String("conversation_id", conversation.ID))
		return nil
	}

	// The cached location is formatted as "country|province|city".
	location := fmt.Sprintf("%s|%s|%s", addr.Country, addr.Province, addr.City)
	if cacheErr := u.geoCacheRepo.SetGeo(ctx, conversation.KBID, location); cacheErr != nil {
		u.logger.Warn("set geo cache failed", log.Error(cacheErr), log.String("conversation_id", conversation.ID), log.String("ip", ip))
	}
	return nil
}
// FeedBack records a user's vote (and optional suggestion) on a message.
//
// A message may be voted on only once: if its stored score is already
// non-zero, an error is returned and nothing is written. Repository errors
// are returned unchanged.
func (u *ConversationUsecase) FeedBack(ctx context.Context, feedback *domain.FeedbackRequest) error {
	// Load the message's current state first.
	message, err := u.repo.GetConversationMessagesDetailByID(ctx, feedback.MessageId)
	if err != nil {
		return err
	}
	u.logger.Debug("feedback info", log.Any("feedback_info", message.Info))

	// Server-side guard: only one vote per message.
	// NOTE(review): the read above and the update below are not atomic, so
	// two concurrent requests could both pass this check.
	if message.Info.Score != 0 {
		return fmt.Errorf("already voted for this message, please do not vote again")
	}

	// The user may attach a suggestion along with the vote.
	return u.repo.UpdateMessageFeedback(ctx, feedback)
}
// GetMessageList returns one page of the repository's feedback message list
// (GetMessageFeedBackList). Each item is enriched with the resolved auth user
// profile and the geolocation of its remote IP.
//
// Enrichment is best-effort: user-info and IP lookup failures are logged and
// the affected fields keep the values returned by the repository. Only an
// error from the list query itself is returned.
func (u *ConversationUsecase) GetMessageList(ctx context.Context, req *domain.MessageListReq) (*domain.PaginatedResult[[]*domain.ConversationMessageListItem], error) {
	total, messageList, err := u.repo.GetMessageFeedBackList(ctx, req)
	if err != nil {
		return nil, err
	}

	// Collect auth user IDs; 0 means the message has no auth user.
	authIDs := make([]uint, 0, len(messageList))
	for _, message := range messageList {
		if message.ConversationInfo.UserInfo.AuthUserID != 0 {
			authIDs = append(authIDs, message.ConversationInfo.UserInfo.AuthUserID)
		}
	}
	// Several messages usually share one user; query each ID once.
	authMap, err := u.authRepo.GetAuthUserinfoByIDs(ctx, lo.Uniq(authIDs))
	if err != nil {
		u.logger.Error("get user info failed", log.Error(err))
	}

	// Per-request IP cache. Failed lookups are cached as nil so a bad IP is
	// queried (and logged) only once.
	ipCache := make(map[string]*domain.IPAddress)
	lookupIP := func(ip string) *domain.IPAddress {
		if addr, ok := ipCache[ip]; ok {
			return addr
		}
		addr, lookupErr := u.ipRepo.GetIPAddress(ctx, ip)
		if lookupErr != nil {
			u.logger.Error("get ip address failed", log.Error(lookupErr), log.String("ip", ip))
			addr = nil
		}
		ipCache[ip] = addr
		return addr
	}

	for _, message := range messageList {
		if addr := lookupIP(message.RemoteIP); addr != nil {
			message.IPAddress = addr
		}
		// User info is applied regardless of whether the IP lookup succeeded.
		if auth, ok := authMap[message.ConversationInfo.UserInfo.AuthUserID]; ok {
			message.ConversationInfo.UserInfo = domain.UserInfo{
				NickName: auth.AuthUserInfo.Username,
				Avatar:   auth.AuthUserInfo.AvatarUrl,
				Email:    auth.AuthUserInfo.Email,
			}
		}
	}
	return domain.NewPaginatedResult(messageList, uint64(total)), nil
}
// GetMessageDetail fetches a single message by messageId, scoped to the
// knowledge base kbId. On a repository error the message is nil.
func (u *ConversationUsecase) GetMessageDetail(ctx context.Context, kbId, messageId string) (*domain.ConversationMessage, error) {
	msg, err := u.repo.GetConversationMessagesDetailByKbID(ctx, kbId, messageId)
	if err == nil {
		return msg, nil
	}
	return nil, err
}
// GetShareConversationDetail builds the shareable view of a conversation in
// knowledge base kbID. Only the conversation's ID, subject and creation time
// are copied, plus each message's role, content and creation time. Other
// conversation fields are not copied.
//
// Repository errors are returned with a nil result.
func (u *ConversationUsecase) GetShareConversationDetail(ctx context.Context, kbID, conversationID string) (*domain.ShareConversationDetailResp, error) {
	conversation, err := u.repo.GetConversationDetail(ctx, kbID, conversationID)
	if err != nil {
		return nil, err
	}
	messages, err := u.repo.GetConversationMessagesByID(ctx, conversationID)
	if err != nil {
		return nil, err
	}

	// Pre-sized and non-nil even when empty, so that if the response is
	// JSON-encoded, Messages serializes as [] instead of null.
	shareMessages := make([]*domain.ShareConversationMessage, 0, len(messages))
	for _, message := range messages {
		shareMessages = append(shareMessages, &domain.ShareConversationMessage{
			Role:      message.Role,
			Content:   message.Content,
			CreatedAt: message.CreatedAt,
		})
	}

	return &domain.ShareConversationDetailResp{
		ID:        conversation.ID,
		Subject:   conversation.Subject,
		CreatedAt: conversation.CreatedAt,
		Messages:  shareMessages,
	}, nil
}