mirror of https://github.com/chaitin/PandaWiki.git
Compare commits
7 Commits
5ba524b19f
...
9c02e5d93b
| Author | SHA1 | Date |
|---|---|---|
|
|
9c02e5d93b | |
|
|
18e376c40d | |
|
|
a6f4688b88 | |
|
|
575f51f0ea | |
|
|
83f6853716 | |
|
|
3dae8e8d01 | |
|
|
2e1e1848c4 |
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -16,8 +15,6 @@ import (
|
|||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"github.com/samber/lo"
|
||||
"github.com/samber/lo/parallel"
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/chaitin/panda-wiki/config"
|
||||
"github.com/chaitin/panda-wiki/domain"
|
||||
|
|
@ -39,6 +36,12 @@ type LLMUsecase struct {
|
|||
modelkit *modelkit.ModelKit
|
||||
}
|
||||
|
||||
// Token budgets for the map-reduce document summarisation pipeline:
// the document is split into token-limited chunks, each chunk is
// summarised independently, and the per-chunk summaries are aggregated
// into a final summary.
const (
	summaryChunkTokenLimit = 30720 // 30KB tokens per chunk
	summaryMaxChunks       = 4     // max chunks to process for summary
	summaryAggregateLimit  = 8192  // max tokens for aggregating summaries
)
|
||||
|
||||
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
|
||||
tiktoken.SetBpeLoader(&utils.Localloader{})
|
||||
modelkit := modelkit.NewModelKit(logger.Logger)
|
||||
|
|
@ -197,52 +200,91 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
|
|||
return "", err
|
||||
}
|
||||
|
||||
chunks, err := u.SplitByTokenLimit(content, int(math.Floor(1024*32*0.95)))
|
||||
chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sem := semaphore.NewWeighted(int64(10))
|
||||
summaries := parallel.Map(chunks, func(chunk string, _ int) string {
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
u.logger.Error("Failed to acquire semaphore for chunk: ", log.Error(err))
|
||||
return ""
|
||||
}
|
||||
defer sem.Release(1)
|
||||
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "你是文档总结助手,请根据文档内容总结出文档的摘要。摘要是纯文本,应该简洁明了,不要超过160个字。",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("文档名称:%s\n文档内容:%s", name, chunk),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
u.logger.Error("Failed to generate summary for chunk: ", log.Error(err))
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(summary, "<think>") {
|
||||
// remove <think> body </think>
|
||||
endIndex := strings.Index(summary, "</think>")
|
||||
if endIndex != -1 {
|
||||
summary = strings.TrimSpace(summary[endIndex+8:]) // 8 is length of "</think>"
|
||||
}
|
||||
}
|
||||
return summary
|
||||
})
|
||||
// 使用lo.Filter处理错误
|
||||
defeatSummary := lo.Filter(summaries, func(summary string, index int) bool {
|
||||
return summary == ""
|
||||
})
|
||||
if len(defeatSummary) > 0 {
|
||||
return "", fmt.Errorf("failed to generate summaries for all chunks: %d/%d", len(defeatSummary), len(chunks))
|
||||
if len(chunks) > summaryMaxChunks {
|
||||
u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
|
||||
chunks = chunks[:summaryMaxChunks]
|
||||
}
|
||||
|
||||
contents, err := u.SplitByTokenLimit(strings.Join(summaries, "\n\n"), int(math.Floor(1024*32*0.95)))
|
||||
if err != nil {
|
||||
return "", err
|
||||
summaries := make([]string, 0, len(chunks))
|
||||
for idx, chunk := range chunks {
|
||||
summary, err := u.requestSummary(ctx, chatModel, name, chunk)
|
||||
if err != nil {
|
||||
u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(err))
|
||||
continue
|
||||
}
|
||||
if summary == "" {
|
||||
u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
|
||||
continue
|
||||
}
|
||||
summaries = append(summaries, summary)
|
||||
}
|
||||
|
||||
if len(summaries) == 0 {
|
||||
return "", fmt.Errorf("failed to generate summary for document %s", name)
|
||||
}
|
||||
|
||||
// Iteratively aggregate summaries if needed
|
||||
for len(summaries) > 1 {
|
||||
joined := strings.Join(summaries, "\n\n")
|
||||
tokens, err := u.countTokens(joined)
|
||||
if err != nil {
|
||||
u.logger.Warn("Failed to count tokens for aggregation, proceeding anyway", log.Error(err))
|
||||
break
|
||||
}
|
||||
if tokens <= summaryAggregateLimit {
|
||||
break
|
||||
}
|
||||
// If still too large, aggregate in batches
|
||||
u.logger.Debug("aggregating summaries in batches", log.Int("current_summaries", len(summaries)), log.Int("tokens", tokens))
|
||||
batchSize := 2
|
||||
newSummaries := make([]string, 0, (len(summaries)+batchSize-1)/batchSize)
|
||||
for i := 0; i < len(summaries); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(summaries) {
|
||||
end = len(summaries)
|
||||
}
|
||||
batch := strings.Join(summaries[i:end], "\n\n")
|
||||
summary, err := u.requestSummary(ctx, chatModel, name, batch)
|
||||
if err != nil {
|
||||
u.logger.Error("Failed to aggregate summary batch", log.Int("batch_start", i), log.Error(err))
|
||||
// Fallback: use the first summary in the batch
|
||||
newSummaries = append(newSummaries, summaries[i])
|
||||
continue
|
||||
}
|
||||
newSummaries = append(newSummaries, summary)
|
||||
}
|
||||
summaries = newSummaries
|
||||
}
|
||||
|
||||
joined := strings.Join(summaries, "\n\n")
|
||||
finalSummary, err := u.requestSummary(ctx, chatModel, name, joined)
|
||||
if err != nil {
|
||||
u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
|
||||
// Fallback: return the joined summaries directly
|
||||
if len(joined) > 500 {
|
||||
return joined[:500] + "...", nil
|
||||
}
|
||||
return joined, nil
|
||||
}
|
||||
return finalSummary, nil
|
||||
}
|
||||
|
||||
func (u *LLMUsecase) trimThinking(summary string) string {
|
||||
if !strings.HasPrefix(summary, "<think>") {
|
||||
return summary
|
||||
}
|
||||
endIndex := strings.Index(summary, "</think>")
|
||||
if endIndex == -1 {
|
||||
return summary
|
||||
}
|
||||
return strings.TrimSpace(summary[endIndex+len("</think>"):])
|
||||
}
|
||||
|
||||
func (u *LLMUsecase) requestSummary(ctx context.Context, chatModel model.BaseChatModel, name, content string) (string, error) {
|
||||
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
|
||||
{
|
||||
Role: "system",
|
||||
|
|
@ -250,20 +292,22 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
|
|||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("文档名称:%s\n文档内容:%s", name, contents[0]),
|
||||
Content: fmt.Sprintf("文档名称:%s\n文档内容:%s", name, content),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.HasPrefix(summary, "<think>") {
|
||||
// remove <think> body </think>
|
||||
endIndex := strings.Index(summary, "</think>")
|
||||
if endIndex != -1 {
|
||||
summary = strings.TrimSpace(summary[endIndex+8:]) // 8 is length of "</think>"
|
||||
}
|
||||
return strings.TrimSpace(u.trimThinking(summary)), nil
|
||||
}
|
||||
|
||||
func (u *LLMUsecase) countTokens(text string) (int, error) {
|
||||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get encoding: %w", err)
|
||||
}
|
||||
return summary, nil
|
||||
tokens := encoding.Encode(text, nil, nil)
|
||||
return len(tokens), nil
|
||||
}
|
||||
|
||||
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue