Compare commits

...

7 Commits

Author SHA1 Message Date
xiaomakuaiz 9c02e5d93b
Merge 18e376c40d into 3032384457 2025-11-07 03:39:15 +00:00
monkeycode-ai 18e376c40d Improve summary optimization with better token limits and aggregation
优化摘要生成逻辑:
1. 将chunk token限制从16KB提升到30KB,更合理地利用模型上下文
2. 添加迭代聚合逻辑,当多个summaries合并后超过8KB时,分批聚合
3. 添加fallback机制,当最终摘要生成失败时返回已聚合的摘要
4. 新增countTokens辅助方法用于精确计算token数量

这些改进确保了:
- 长文档能够更充分地被摘要(30KB vs 16KB)
- 多个chunk的摘要能够智能聚合,避免超出token限制
- 即使最终摘要失败也能返回有用的结果

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-authored-by: monkeycode-ai <monkeycode-ai@chaitin.com>
2025-11-07 11:35:56 +08:00
monkeycode-ai a6f4688b88 Run goimports on llm summary
Co-authored-by: monkeycode-ai <monkeycode-ai@chaitin.com>
2025-11-06 20:07:30 +08:00
monkeycode-ai 575f51f0ea Simplify final summary aggregation
Co-authored-by: monkeycode-ai <monkeycode-ai@chaitin.com>
2025-11-06 19:52:03 +08:00
monkeycode-ai 83f6853716 Iteratively reduce summary chunks
Co-authored-by: monkeycode-ai <monkeycode-ai@chaitin.com>
2025-11-06 19:21:26 +08:00
monkeycode-ai 3dae8e8d01 Raise summary chunk limit to 16k
Co-authored-by: monkeycode-ai <monkeycode-ai@chaitin.com>
2025-11-06 19:17:27 +08:00
monkeycode-ai 2e1e1848c4 Adjust summary chunking and concurrency
Co-authored-by: monkeycode-ai <monkeycode-ai@chaitin.com>
2025-11-06 19:09:15 +08:00
1 changed file with 95 additions and 51 deletions

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
"math"
"slices"
"strings"
"time"
@@ -16,8 +15,6 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/pkoukk/tiktoken-go"
"github.com/samber/lo"
"github.com/samber/lo/parallel"
"golang.org/x/sync/semaphore"
"github.com/chaitin/panda-wiki/config"
"github.com/chaitin/panda-wiki/domain"
@ -39,6 +36,12 @@ type LLMUsecase struct {
modelkit *modelkit.ModelKit
}
// Tuning knobs for document summarization. All values are token
// counts (not bytes); the "KB" wording in commit messages is shorthand
// for multiples of 1024 tokens.
const (
	summaryChunkTokenLimit = 30720 // 30*1024 tokens per input chunk sent to the model
	summaryMaxChunks       = 4     // max chunks summarized per document; extra chunks are dropped
	summaryAggregateLimit  = 8192  // max tokens when joining per-chunk summaries before re-aggregating
)
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
tiktoken.SetBpeLoader(&utils.Localloader{})
modelkit := modelkit.NewModelKit(logger.Logger)
@@ -197,52 +200,91 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
return "", err
}
chunks, err := u.SplitByTokenLimit(content, int(math.Floor(1024*32*0.95)))
chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
if err != nil {
return "", err
}
sem := semaphore.NewWeighted(int64(10))
summaries := parallel.Map(chunks, func(chunk string, _ int) string {
if err := sem.Acquire(ctx, 1); err != nil {
u.logger.Error("Failed to acquire semaphore for chunk: ", log.Error(err))
return ""
}
defer sem.Release(1)
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
{
Role: "system",
Content: "你是文档总结助手请根据文档内容总结出文档的摘要。摘要是纯文本应该简洁明了不要超过160个字。",
},
{
Role: "user",
Content: fmt.Sprintf("文档名称:%s\n文档内容%s", name, chunk),
},
})
if err != nil {
u.logger.Error("Failed to generate summary for chunk: ", log.Error(err))
return ""
}
if strings.HasPrefix(summary, "<think>") {
// remove <think> body </think>
endIndex := strings.Index(summary, "</think>")
if endIndex != -1 {
summary = strings.TrimSpace(summary[endIndex+8:]) // 8 is length of "</think>"
}
}
return summary
})
// 使用lo.Filter处理错误
defeatSummary := lo.Filter(summaries, func(summary string, index int) bool {
return summary == ""
})
if len(defeatSummary) > 0 {
return "", fmt.Errorf("failed to generate summaries for all chunks: %d/%d", len(defeatSummary), len(chunks))
if len(chunks) > summaryMaxChunks {
u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
chunks = chunks[:summaryMaxChunks]
}
contents, err := u.SplitByTokenLimit(strings.Join(summaries, "\n\n"), int(math.Floor(1024*32*0.95)))
if err != nil {
return "", err
summaries := make([]string, 0, len(chunks))
for idx, chunk := range chunks {
summary, err := u.requestSummary(ctx, chatModel, name, chunk)
if err != nil {
u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(err))
continue
}
if summary == "" {
u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
continue
}
summaries = append(summaries, summary)
}
if len(summaries) == 0 {
return "", fmt.Errorf("failed to generate summary for document %s", name)
}
// Iteratively aggregate summaries if needed
for len(summaries) > 1 {
joined := strings.Join(summaries, "\n\n")
tokens, err := u.countTokens(joined)
if err != nil {
u.logger.Warn("Failed to count tokens for aggregation, proceeding anyway", log.Error(err))
break
}
if tokens <= summaryAggregateLimit {
break
}
// If still too large, aggregate in batches
u.logger.Debug("aggregating summaries in batches", log.Int("current_summaries", len(summaries)), log.Int("tokens", tokens))
batchSize := 2
newSummaries := make([]string, 0, (len(summaries)+batchSize-1)/batchSize)
for i := 0; i < len(summaries); i += batchSize {
end := i + batchSize
if end > len(summaries) {
end = len(summaries)
}
batch := strings.Join(summaries[i:end], "\n\n")
summary, err := u.requestSummary(ctx, chatModel, name, batch)
if err != nil {
u.logger.Error("Failed to aggregate summary batch", log.Int("batch_start", i), log.Error(err))
// Fallback: use the first summary in the batch
newSummaries = append(newSummaries, summaries[i])
continue
}
newSummaries = append(newSummaries, summary)
}
summaries = newSummaries
}
joined := strings.Join(summaries, "\n\n")
finalSummary, err := u.requestSummary(ctx, chatModel, name, joined)
if err != nil {
u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
// Fallback: return the joined summaries directly
if len(joined) > 500 {
return joined[:500] + "...", nil
}
return joined, nil
}
return finalSummary, nil
}
// trimThinking strips a leading "<think> ... </think>" reasoning block
// from a model response and returns the trimmed remainder. Responses
// without a leading <think> tag, or without a closing tag, are
// returned unchanged.
func (u *LLMUsecase) trimThinking(summary string) string {
	// Only responses that open with a <think> block need trimming.
	if strings.HasPrefix(summary, "<think>") {
		// Drop everything up to and including the first closing tag.
		if _, rest, found := strings.Cut(summary, "</think>"); found {
			return strings.TrimSpace(rest)
		}
	}
	return summary
}
func (u *LLMUsecase) requestSummary(ctx context.Context, chatModel model.BaseChatModel, name, content string) (string, error) {
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
{
Role: "system",
@@ -250,20 +292,22 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
},
{
Role: "user",
Content: fmt.Sprintf("文档名称:%s\n文档内容%s", name, contents[0]),
Content: fmt.Sprintf("文档名称:%s\n文档内容%s", name, content),
},
})
if err != nil {
return "", err
}
if strings.HasPrefix(summary, "<think>") {
// remove <think> body </think>
endIndex := strings.Index(summary, "</think>")
if endIndex != -1 {
summary = strings.TrimSpace(summary[endIndex+8:]) // 8 is length of "</think>"
}
return strings.TrimSpace(u.trimThinking(summary)), nil
}
// countTokens returns the number of tokens in text under the
// cl100k_base tiktoken encoding, so chunking and aggregation limits
// are measured the same way the tokenizer counts them.
//
// Fix: removed the stray `return summary, nil` line — diff residue
// from the old requestSummary hunk that references an undefined
// variable and would not compile.
func (u *LLMUsecase) countTokens(text string) (int, error) {
	encoding, err := tiktoken.GetEncoding("cl100k_base")
	if err != nil {
		return 0, fmt.Errorf("failed to get encoding: %w", err)
	}
	tokens := encoding.Encode(text, nil, nil)
	return len(tokens), nil
}
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {