mirror of https://github.com/chaitin/PandaWiki.git
Compare commits
2 Commits
5ba524b19f
...
9c02e5d93b
| Author | SHA1 | Date |
|---|---|---|
|
|
9c02e5d93b | |
|
|
18e376c40d |
|
|
@ -37,8 +37,9 @@ type LLMUsecase struct {
|
|||
}
|
||||
|
||||
const (
|
||||
summaryChunkTokenLimit = 16384
|
||||
summaryMaxChunks = 4
|
||||
summaryChunkTokenLimit = 30720 // 30KB tokens per chunk
|
||||
summaryMaxChunks = 4 // max chunks to process for summary
|
||||
summaryAggregateLimit = 8192 // max tokens for aggregating summaries
|
||||
)
|
||||
|
||||
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
|
||||
|
|
@ -226,8 +227,50 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
|
|||
return "", fmt.Errorf("failed to generate summary for document %s", name)
|
||||
}
|
||||
|
||||
// Iteratively aggregate summaries if needed
|
||||
for len(summaries) > 1 {
|
||||
joined := strings.Join(summaries, "\n\n")
|
||||
return u.requestSummary(ctx, chatModel, name, joined)
|
||||
tokens, err := u.countTokens(joined)
|
||||
if err != nil {
|
||||
u.logger.Warn("Failed to count tokens for aggregation, proceeding anyway", log.Error(err))
|
||||
break
|
||||
}
|
||||
if tokens <= summaryAggregateLimit {
|
||||
break
|
||||
}
|
||||
// If still too large, aggregate in batches
|
||||
u.logger.Debug("aggregating summaries in batches", log.Int("current_summaries", len(summaries)), log.Int("tokens", tokens))
|
||||
batchSize := 2
|
||||
newSummaries := make([]string, 0, (len(summaries)+batchSize-1)/batchSize)
|
||||
for i := 0; i < len(summaries); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(summaries) {
|
||||
end = len(summaries)
|
||||
}
|
||||
batch := strings.Join(summaries[i:end], "\n\n")
|
||||
summary, err := u.requestSummary(ctx, chatModel, name, batch)
|
||||
if err != nil {
|
||||
u.logger.Error("Failed to aggregate summary batch", log.Int("batch_start", i), log.Error(err))
|
||||
// Fallback: use the first summary in the batch
|
||||
newSummaries = append(newSummaries, summaries[i])
|
||||
continue
|
||||
}
|
||||
newSummaries = append(newSummaries, summary)
|
||||
}
|
||||
summaries = newSummaries
|
||||
}
|
||||
|
||||
joined := strings.Join(summaries, "\n\n")
|
||||
finalSummary, err := u.requestSummary(ctx, chatModel, name, joined)
|
||||
if err != nil {
|
||||
u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
|
||||
// Fallback: return the joined summaries directly
|
||||
if len(joined) > 500 {
|
||||
return joined[:500] + "...", nil
|
||||
}
|
||||
return joined, nil
|
||||
}
|
||||
return finalSummary, nil
|
||||
}
|
||||
|
||||
func (u *LLMUsecase) trimThinking(summary string) string {
|
||||
|
|
@ -258,6 +301,15 @@ func (u *LLMUsecase) requestSummary(ctx context.Context, chatModel model.BaseCha
|
|||
return strings.TrimSpace(u.trimThinking(summary)), nil
|
||||
}
|
||||
|
||||
func (u *LLMUsecase) countTokens(text string) (int, error) {
|
||||
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get encoding: %w", err)
|
||||
}
|
||||
tokens := encoding.Encode(text, nil, nil)
|
||||
return len(tokens), nil
|
||||
}
|
||||
|
||||
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
|
||||
if maxTokens <= 0 {
|
||||
return nil, fmt.Errorf("maxTokens must be greater than 0")
|
||||
|
|
|
|||
Loading…
Reference in New Issue