mirror of https://github.com/chaitin/PandaWiki.git
Compare commits
7 Commits
9c02e5d93b
...
5078e93a4a
| Author | SHA1 | Date |
|---|---|---|
|
|
5078e93a4a | |
|
|
f7c0fe273b | |
|
|
a6f4688b88 | |
|
|
575f51f0ea | |
|
|
83f6853716 | |
|
|
3dae8e8d01 | |
|
|
2e1e1848c4 |
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -16,8 +15,6 @@ import (
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/pkoukk/tiktoken-go"
|
"github.com/pkoukk/tiktoken-go"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"github.com/samber/lo/parallel"
|
|
||||||
"golang.org/x/sync/semaphore"
|
|
||||||
|
|
||||||
"github.com/chaitin/panda-wiki/config"
|
"github.com/chaitin/panda-wiki/config"
|
||||||
"github.com/chaitin/panda-wiki/domain"
|
"github.com/chaitin/panda-wiki/domain"
|
||||||
|
|
@ -39,6 +36,11 @@ type LLMUsecase struct {
|
||||||
modelkit *modelkit.ModelKit
|
modelkit *modelkit.ModelKit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Limits applied when summarizing a document in SummaryNode.
const (
|
||||||
|
summaryChunkTokenLimit = 30720 // 30 * 1024 tokens; passed to SplitByTokenLimit as the per-chunk maximum
|
||||||
|
summaryMaxChunks = 4 // only the first summaryMaxChunks chunks are summarized; later chunks are dropped
|
||||||
|
)
|
||||||
|
|
||||||
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
|
func NewLLMUsecase(config *config.Config, rag rag.RAGService, conversationRepo *pg.ConversationRepository, kbRepo *pg.KnowledgeBaseRepository, nodeRepo *pg.NodeRepository, modelRepo *pg.ModelRepository, promptRepo *pg.PromptRepo, logger *log.Logger) *LLMUsecase {
|
||||||
tiktoken.SetBpeLoader(&utils.Localloader{})
|
tiktoken.SetBpeLoader(&utils.Localloader{})
|
||||||
modelkit := modelkit.NewModelKit(logger.Logger)
|
modelkit := modelkit.NewModelKit(logger.Logger)
|
||||||
|
|
@ -197,52 +199,59 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks, err := u.SplitByTokenLimit(content, int(math.Floor(1024*32*0.95)))
|
chunks, err := u.SplitByTokenLimit(content, summaryChunkTokenLimit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
sem := semaphore.NewWeighted(int64(10))
|
if len(chunks) > summaryMaxChunks {
|
||||||
summaries := parallel.Map(chunks, func(chunk string, _ int) string {
|
u.logger.Debug("trim summary chunks for large document", log.String("node", name), log.Int("original_chunks", len(chunks)), log.Int("used_chunks", summaryMaxChunks))
|
||||||
if err := sem.Acquire(ctx, 1); err != nil {
|
chunks = chunks[:summaryMaxChunks]
|
||||||
u.logger.Error("Failed to acquire semaphore for chunk: ", log.Error(err))
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
defer sem.Release(1)
|
|
||||||
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
|
|
||||||
{
|
|
||||||
Role: "system",
|
|
||||||
Content: "你是文档总结助手,请根据文档内容总结出文档的摘要。摘要是纯文本,应该简洁明了,不要超过160个字。",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: "user",
|
|
||||||
Content: fmt.Sprintf("文档名称:%s\n文档内容:%s", name, chunk),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
u.logger.Error("Failed to generate summary for chunk: ", log.Error(err))
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(summary, "<think>") {
|
|
||||||
// remove <think> body </think>
|
|
||||||
endIndex := strings.Index(summary, "</think>")
|
|
||||||
if endIndex != -1 {
|
|
||||||
summary = strings.TrimSpace(summary[endIndex+8:]) // 8 is length of "</think>"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return summary
|
|
||||||
})
|
|
||||||
// 使用lo.Filter处理错误
|
|
||||||
defeatSummary := lo.Filter(summaries, func(summary string, index int) bool {
|
|
||||||
return summary == ""
|
|
||||||
})
|
|
||||||
if len(defeatSummary) > 0 {
|
|
||||||
return "", fmt.Errorf("failed to generate summaries for all chunks: %d/%d", len(defeatSummary), len(chunks))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
contents, err := u.SplitByTokenLimit(strings.Join(summaries, "\n\n"), int(math.Floor(1024*32*0.95)))
|
summaries := make([]string, 0, len(chunks))
|
||||||
|
for idx, chunk := range chunks {
|
||||||
|
summary, err := u.requestSummary(ctx, chatModel, name, chunk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
u.logger.Error("Failed to generate summary for chunk", log.Int("chunk_index", idx), log.Error(err))
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
if summary == "" {
|
||||||
|
u.logger.Warn("Empty summary returned for chunk", log.Int("chunk_index", idx))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
summaries = append(summaries, summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(summaries) == 0 {
|
||||||
|
return "", fmt.Errorf("failed to generate summary for document %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join all summaries and generate final summary
|
||||||
|
joined := strings.Join(summaries, "\n\n")
|
||||||
|
finalSummary, err := u.requestSummary(ctx, chatModel, name, joined)
|
||||||
|
if err != nil {
|
||||||
|
u.logger.Error("Failed to generate final summary, using aggregated summaries", log.Error(err))
|
||||||
|
// Fallback: return the joined summaries directly
|
||||||
|
if len(joined) > 500 {
|
||||||
|
return joined[:500] + "...", nil
|
||||||
|
}
|
||||||
|
return joined, nil
|
||||||
|
}
|
||||||
|
return finalSummary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *LLMUsecase) trimThinking(summary string) string {
|
||||||
|
if !strings.HasPrefix(summary, "<think>") {
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
endIndex := strings.Index(summary, "</think>")
|
||||||
|
if endIndex == -1 {
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(summary[endIndex+len("</think>"):])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *LLMUsecase) requestSummary(ctx context.Context, chatModel model.BaseChatModel, name, content string) (string, error) {
|
||||||
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
|
summary, err := u.Generate(ctx, chatModel, []*schema.Message{
|
||||||
{
|
{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
|
|
@ -250,20 +259,13 @@ func (u *LLMUsecase) SummaryNode(ctx context.Context, model *domain.Model, name,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: fmt.Sprintf("文档名称:%s\n文档内容:%s", name, contents[0]),
|
Content: fmt.Sprintf("文档名称:%s\n文档内容:%s", name, content),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(summary, "<think>") {
|
return strings.TrimSpace(u.trimThinking(summary)), nil
|
||||||
// remove <think> body </think>
|
|
||||||
endIndex := strings.Index(summary, "</think>")
|
|
||||||
if endIndex != -1 {
|
|
||||||
summary = strings.TrimSpace(summary[endIndex+8:]) // 8 is length of "</think>"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return summary, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
|
func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, error) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue