mirror of https://github.com/ollama/ollama.git
135 lines
3.4 KiB
Go
135 lines
3.4 KiB
Go
|
package thinking
|
|||
|
|
|||
|
import (
|
|||
|
"strings"
|
|||
|
"text/template"
|
|||
|
"text/template/parse"
|
|||
|
)
|
|||
|
|
|||
|
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
|
|||
|
if n == nil {
|
|||
|
return
|
|||
|
}
|
|||
|
shouldContinue := enterFn(n)
|
|||
|
if !shouldContinue {
|
|||
|
return
|
|||
|
}
|
|||
|
switch x := n.(type) {
|
|||
|
case *parse.ListNode:
|
|||
|
for _, c := range x.Nodes {
|
|||
|
templateVisit(c, enterFn, exitFn)
|
|||
|
}
|
|||
|
case *parse.BranchNode:
|
|||
|
if x.Pipe != nil {
|
|||
|
templateVisit(x.Pipe, enterFn, exitFn)
|
|||
|
}
|
|||
|
if x.List != nil {
|
|||
|
templateVisit(x.List, enterFn, exitFn)
|
|||
|
}
|
|||
|
if x.ElseList != nil {
|
|||
|
templateVisit(x.ElseList, enterFn, exitFn)
|
|||
|
}
|
|||
|
case *parse.ActionNode:
|
|||
|
templateVisit(x.Pipe, enterFn, exitFn)
|
|||
|
case *parse.WithNode:
|
|||
|
templateVisit(&x.BranchNode, enterFn, exitFn)
|
|||
|
case *parse.RangeNode:
|
|||
|
templateVisit(&x.BranchNode, enterFn, exitFn)
|
|||
|
case *parse.IfNode:
|
|||
|
templateVisit(&x.BranchNode, enterFn, exitFn)
|
|||
|
case *parse.TemplateNode:
|
|||
|
templateVisit(x.Pipe, enterFn, exitFn)
|
|||
|
case *parse.PipeNode:
|
|||
|
for _, c := range x.Cmds {
|
|||
|
templateVisit(c, enterFn, exitFn)
|
|||
|
}
|
|||
|
case *parse.CommandNode:
|
|||
|
for _, a := range x.Args {
|
|||
|
templateVisit(a, enterFn, exitFn)
|
|||
|
}
|
|||
|
// text, field, number, etc. are leaves – nothing to recurse into
|
|||
|
}
|
|||
|
if exitFn != nil {
|
|||
|
exitFn(n)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// InferTags uses a heuristic to infer the tags that surround thinking traces:
|
|||
|
// We look for a range node that iterates over "Messages" and then look for a
|
|||
|
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
|
|||
|
// ListNode and take the first and last TextNodes as the opening and closing
|
|||
|
// tags.
|
|||
|
func InferTags(t *template.Template) (string, string) {
|
|||
|
ancestors := []parse.Node{}
|
|||
|
|
|||
|
openingTag := ""
|
|||
|
closingTag := ""
|
|||
|
|
|||
|
enterFn := func(n parse.Node) bool {
|
|||
|
ancestors = append(ancestors, n)
|
|||
|
|
|||
|
switch x := n.(type) {
|
|||
|
case *parse.FieldNode:
|
|||
|
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
|
|||
|
var mostRecentRange *parse.RangeNode
|
|||
|
for i := len(ancestors) - 1; i >= 0; i-- {
|
|||
|
if r, ok := ancestors[i].(*parse.RangeNode); ok {
|
|||
|
mostRecentRange = r
|
|||
|
break
|
|||
|
}
|
|||
|
}
|
|||
|
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
|
|||
|
return true
|
|||
|
}
|
|||
|
|
|||
|
// TODO(drifkin): to be more robust, check that it's in the action
|
|||
|
// part, not the `if`'s pipeline part. We do match on the nearest list
|
|||
|
// that starts and ends with text nodes, which makes this not strictly
|
|||
|
// necessary for our heuristic
|
|||
|
|
|||
|
// go up to the nearest ancestor that is a *parse.ListNode
|
|||
|
for i := len(ancestors) - 1; i >= 0; i-- {
|
|||
|
if l, ok := ancestors[i].(*parse.ListNode); ok {
|
|||
|
firstNode := l.Nodes[0]
|
|||
|
if t, ok := firstNode.(*parse.TextNode); ok {
|
|||
|
openingTag = strings.TrimSpace(t.String())
|
|||
|
}
|
|||
|
lastNode := l.Nodes[len(l.Nodes)-1]
|
|||
|
if t, ok := lastNode.(*parse.TextNode); ok {
|
|||
|
closingTag = strings.TrimSpace(t.String())
|
|||
|
}
|
|||
|
|
|||
|
break
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
return true
|
|||
|
}
|
|||
|
|
|||
|
exitFn := func(n parse.Node) {
|
|||
|
ancestors = ancestors[:len(ancestors)-1]
|
|||
|
}
|
|||
|
|
|||
|
templateVisit(t.Root, enterFn, exitFn)
|
|||
|
|
|||
|
return openingTag, closingTag
|
|||
|
}
|
|||
|
|
|||
|
// checks to see if the given field name is present in the pipeline of the given range node
|
|||
|
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
|
|||
|
found := false
|
|||
|
enterFn := func(n parse.Node) bool {
|
|||
|
switch x := n.(type) {
|
|||
|
case *parse.FieldNode:
|
|||
|
if x.Ident[0] == field {
|
|||
|
found = true
|
|||
|
}
|
|||
|
}
|
|||
|
return true
|
|||
|
}
|
|||
|
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
|
|||
|
return found
|
|||
|
}
|