mirror of https://github.com/ollama/ollama.git
tools: relax JSON parse constraints for tool calling (#10872)
This commit is contained in:
parent
aea6fb9b58
commit
066d0f4746
|
@ -17,15 +17,14 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Parser struct {
|
type Parser struct {
|
||||||
parseLeadingJSON bool
|
greedyParseJSON bool
|
||||||
prefix string
|
prefix string
|
||||||
prefixFound bool
|
prefixFound bool
|
||||||
tmpl gotmpl.Template
|
tmpl gotmpl.Template
|
||||||
sb strings.Builder
|
sb strings.Builder
|
||||||
index int
|
index int
|
||||||
name string
|
name string
|
||||||
arguments string
|
arguments string
|
||||||
done bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
|
@ -176,14 +175,6 @@ func (p *Parser) checkPrefix(s string) (string, error) {
|
||||||
// - tools: Any parsed tool calls
|
// - tools: Any parsed tool calls
|
||||||
// - content: Non-tool call content
|
// - content: Non-tool call content
|
||||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||||
if p.done {
|
|
||||||
if p.index == 0 {
|
|
||||||
// Return original string if no tool calls found at start
|
|
||||||
return nil, s
|
|
||||||
}
|
|
||||||
// Return empty if no tool calls found after start
|
|
||||||
return nil, ""
|
|
||||||
}
|
|
||||||
p.sb.WriteString(s)
|
p.sb.WriteString(s)
|
||||||
s = p.sb.String()
|
s = p.sb.String()
|
||||||
|
|
||||||
|
@ -195,7 +186,7 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||||
if !p.parseLeadingJSON && !p.prefixFound {
|
if !p.greedyParseJSON && !p.prefixFound {
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
return nil, s
|
return nil, s
|
||||||
}
|
}
|
||||||
|
@ -206,10 +197,9 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||||
return nil, ""
|
return nil, ""
|
||||||
}
|
}
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
// Do not try parsing leading JSON if JSON not found
|
// Only do greedy JSON parsing if there is no prefix from template
|
||||||
p.parseLeadingJSON = false
|
if p.prefix != "" {
|
||||||
if p.prefix == "" {
|
p.greedyParseJSON = false
|
||||||
p.done = true
|
|
||||||
}
|
}
|
||||||
if p.index != 0 && p.prefix == "" {
|
if p.index != 0 && p.prefix == "" {
|
||||||
return nil, ""
|
return nil, ""
|
||||||
|
@ -253,11 +243,11 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Parser{
|
return &Parser{
|
||||||
tmpl: *tt,
|
tmpl: *tt,
|
||||||
sb: strings.Builder{},
|
sb: strings.Builder{},
|
||||||
prefix: tp,
|
prefix: tp,
|
||||||
parseLeadingJSON: true,
|
greedyParseJSON: true,
|
||||||
name: name,
|
name: name,
|
||||||
arguments: arguments,
|
arguments: arguments,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -536,11 +536,18 @@ func TestParseToolCalls(t *testing.T) {
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "model without prefix in template, prefix in output",
|
name: "model without prefix in template, prefix in output, multiple tool calls in list",
|
||||||
model: "llama3.2",
|
model: "llama3.2",
|
||||||
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
expectedTokens: `<tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without prefix in template, prefix in output, individual tool calls",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `<tool_call> {"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: `<tool_call>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "model with prefix in template, no prefix in output, tokens before",
|
name: "model with prefix in template, no prefix in output, tokens before",
|
||||||
|
@ -567,15 +574,37 @@ func TestParseToolCalls(t *testing.T) {
|
||||||
name: "model without prefix in template, no prefix in output, tokens before",
|
name: "model without prefix in template, no prefix in output, tokens before",
|
||||||
model: "llama3.2",
|
model: "llama3.2",
|
||||||
output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
expectedTokens: `some tokens before`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "model without prefix in template, prefix in output, tokens after",
|
name: "model without prefix in template, prefix in output, tokens after",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `<tool_call>
|
||||||
|
[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: `<tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model without without prefix, match all jsons",
|
||||||
model: "llama3.2",
|
model: "llama3.2",
|
||||||
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
output: `model outputs some text [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "model outputs some text",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model flushes tokens if tool call doesn't match",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
expectedTokens: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model flushes tokens if tool call doesn't match array",
|
||||||
|
model: "llama3.2",
|
||||||
|
output: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue