mirror of https://github.com/ollama/ollama.git
tests: add tool calling integration test (#12232)
This commit is contained in:
parent
6745182885
commit
20b53eaa72
|
@ -410,3 +410,99 @@ func TestAPIEmbeddings(t *testing.T) {
|
|||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
modelName := "qwen3:0.6b"
|
||||
if err := PullIfMissing(ctx, client, modelName); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: modelName,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Call get_weather with location set to San Francisco.",
|
||||
},
|
||||
},
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var gotToolCall bool
|
||||
var lastToolCall api.ToolCall
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotToolCall = true
|
||||
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
|
||||
}
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call, got none")
|
||||
}
|
||||
|
||||
if lastToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for tool-calling chat")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue