mirror of https://github.com/chaitin/PandaWiki.git
Compare commits
2 Commits
5ba524b19f
...
9c02e5d93b
| Author | SHA1 | Date |
|---|---|---|
|
|
9c02e5d93b | |
|
|
18e376c40d |
|
|
@ -37,8 +37,9 @@ type LLMUsecase struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
summaryChunkTokenLimit = 16384
|
summaryChunkTokenLimit = 30720 // 30K (30 * 1024) tokens per chunk
|
||||||
summaryMaxChunks = 4
|
summaryMaxChunks = 4 // max chunks to process for summary
|
||||||
|
summaryAggregateLimit = 8192 // max tokens for aggregating summaries
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
|
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
|
||||||
|
|
@ -226,8 +227,50 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
|
||||||
return "", fmt.Errorf("failed to generate summary for document %s", name)
|
return "", fmt.Errorf("failed to generate summary for document %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Iteratively aggregate summaries if needed
|
||||||
|
for len(summaries) > 1 {
|
||||||
|
joined := strings.Join(summaries, "\n\n")
|
||||||
|
tokens, err := u.countTokens(joined)
|
||||||
|
if err != nil {
|
||||||
|
u.logger.Warn("Failed to count tokens for aggregation, proceeding anyway", log.Error(err))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if tokens <= summaryAggregateLimit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// If still too large, aggregate in batches
|
||||||
|
u.logger.Debug("aggregating summaries in batches", log.Int("current_summaries", len(summaries)), log.Int("tokens", tokens))
|
||||||
|
batchSize := 2
|
||||||
|
newSummaries := make([]string, 0, (len(summaries)+batchSize-1)/batchSize)
|
||||||
|
for i := 0; i < len(summaries); i += batchSize {
|
||||||
|
end := i + batchSize
|
||||||
|
if end > len(summaries) {
|
||||||
|
end = len(summaries)
|
||||||
|
}
|
||||||
|
batch := strings.Join(summaries[i:end], "\n\n")
|
||||||
|
summary, err := u.requestSummary(ctx, chatModel, name, batch)
|
||||||
|
if err != nil {
|
||||||
|
u.logger.Error("Failed to aggregate summary batch", log.Int("batch_start", i), log.Error(err))
|
||||||
|
// Fallback: keep only the first summary in the batch; the rest of the batch is dropped
|
||||||
|
newSummaries = append(newSummaries, summaries[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newSummaries = append(newSummaries, summary)
|
||||||
|
}
|
||||||
|
summaries = newSummaries
|
||||||
|
}
|
||||||
|
|
||||||
joined := strings.Join(summaries, "\n\n")
|
joined := strings.Join(summaries, "\n\n")
|
||||||
return u.requestSummary(ctx, chatModel, name, joined)
|
finalSummary, err := u.requestSummary(ctx, chatModel, name, joined)
|
||||||
|
if err != nil {
|
||||||
|
u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
|
||||||
|
// Fallback: return the joined summaries, cut to the first 500 bytes plus "..." when longer (byte-based cut; may split a multi-byte UTF-8 character)
|
||||||
|
if len(joined) > 500 {
|
||||||
|
return joined[:500] + "...", nil
|
||||||
|
}
|
||||||
|
return joined, nil
|
||||||
|
}
|
||||||
|
return finalSummary, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *LLMUsecase) trimThinking(summary string) string {
|
func (u *LLMUsecase) trimThinking(summary string) string {
|
||||||
|
|
@ -258,6 +301,15 @@ func (u *LLMUsecase) requestSummary(ctx context.Context, chatModel model.BaseCha
|
||||||
return strings.TrimSpace(u.trimThinking(summary)), nil
|
return strings.TrimSpace(u.trimThinking(summary)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *LLMUsecase) countTokens(text string) (int, error) {
|
||||||
|
encoding, err := tiktoken.GetEncoding("cl100k_base")
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to get encoding: %w", err)
|
||||||
|
}
|
||||||
|
tokens := encoding.Encode(text, nil, nil)
|
||||||
|
return len(tokens), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
|
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
|
||||||
if maxTokens <= 0 {
|
if maxTokens <= 0 {
|
||||||
return nil, fmt.Errorf("maxTokens must be greater than 0")
|
return nil, fmt.Errorf("maxTokens must be greater than 0")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue