This commit is contained in:
xiaomakuaiz 2025-11-07 17:06:33 +08:00 committed by GitHub
commit 282dc63242
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 54 additions and 52 deletions

View File

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
"math"
"slices"
"strings"
"time"
@ -16,8 +15,6 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/pkoukk/tiktoken-go"
"github.com/samber/lo"
"github.com/samber/lo/parallel"
"golang.org/x/sync/semaphore"
"github.com/chaitin/panda-wiki/config"
"github.com/chaitin/panda-wiki/domain"
@ -39,6 +36,11 @@ type LLMUsecase struct {
modelkit *modelkit.ModelKit
}
const (
summaryChunkTokenLimit = 30720 // 30KB tokens per chunk
summaryMaxChunks = 4 // max chunks to process for summary
)
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
tiktoken.SetBpeLoader(&utils.Localloader{})
modelkit := modelkit.NewModelKit(logger.Logger)
@ -197,52 +199,59 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
return "", err
}
chunks, err := u.SplitByTokenLimit(content, int(math.Floor(1024*32*0.95)))
chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
if err != nil {
return "", err
}
sem := semaphore.NewWeighted(int64(10))
summaries := parallel.Map(chunks, func(chunk string, _ int) string {
if err := sem.Acquire(ctx, 1); err != nil {
u.logger.Error("Failed to acquire semaphore for chunk: ", log.Error(err))
return ""
}
defer sem.Release(1)
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
{
Role: "system",
Content: "你是文档总结助手请根据文档内容总结出文档的摘要。摘要是纯文本应该简洁明了不要超过160个字。",
},
{
Role: "user",
Content: fmt.Sprintf("文档名称:%s\n文档内容%s", name, chunk),
},
})
if err != nil {
u.logger.Error("Failed to generate summary for chunk: ", log.Error(err))
return ""
}
if strings.HasPrefix(summary, "<think>") {
// remove <think> body </think>
endIndex := strings.Index(summary, "</think>")
if endIndex != -1 {
summary = strings.TrimSpace(summary[endIndex+8:]) // 8 is length of "</think>"
}
}
return summary
})
// 使用lo.Filter处理错误
defeatSummary := lo.Filter(summaries, func(summary string, index int) bool {
return summary == ""
})
if len(defeatSummary) > 0 {
return "", fmt.Errorf("failed to generate summaries for all chunks: %d/%d", len(defeatSummary), len(chunks))
if len(chunks) > summaryMaxChunks {
u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
chunks = chunks[:summaryMaxChunks]
}
contents, err := u.SplitByTokenLimit(strings.Join(summaries, "\n\n"), int(math.Floor(1024*32*0.95)))
if err != nil {
return "", err
summaries := make([]string, 0, len(chunks))
for idx, chunk := range chunks {
summary, err := u.requestSummary(ctx, chatModel, name, chunk)
if err != nil {
u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(err))
continue
}
if summary == "" {
u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
continue
}
summaries = append(summaries, summary)
}
if len(summaries) == 0 {
return "", fmt.Errorf("failed to generate summary for document %s", name)
}
// Join all summaries and generate final summary
joined := strings.Join(summaries, "\n\n")
finalSummary, err := u.requestSummary(ctx, chatModel, name, joined)
if err != nil {
u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
// Fallback: return the joined summaries directly
if len(joined) > 500 {
return joined[:500] + "...", nil
}
return joined, nil
}
return finalSummary, nil
}
func (u *LLMUsecase) trimThinking(summary string) string {
if !strings.HasPrefix(summary, "<think>") {
return summary
}
endIndex := strings.Index(summary, "</think>")
if endIndex == -1 {
return summary
}
return strings.TrimSpace(summary[endIndex+len("</think>"):])
}
func (u *LLMUsecase) requestSummary(ctx context.Context, chatModel model.BaseChatModel, name, content string) (string, error) {
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
{
Role: "system",
@ -250,20 +259,13 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
},
{
Role: "user",
Content: fmt.Sprintf("文档名称:%s\n文档内容%s", name, contents[0]),
Content: fmt.Sprintf("文档名称:%s\n文档内容%s", name, content),
},
})
if err != nil {
return "", err
}
if strings.HasPrefix(summary, "<think>") {
// remove <think> body </think>
endIndex := strings.Index(summary, "</think>")
if endIndex != -1 {
summary = strings.TrimSpace(summary[endIndex+8:]) // 8 is length of "</think>"
}
}
return summary, nil
return strings.TrimSpace(u.trimThinking(summary)), nil
}
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {