mirror of https://github.com/ollama/ollama.git
remove extra context, fix non thinking constraining
This commit is contained in:
parent
2047dd2b38
commit
b3d8274741
|
@ -1967,12 +1967,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
structuredOutputsStarted := false
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer cancel()
|
||||
|
||||
applyStructuredOutputs := false
|
||||
|
||||
|
@ -1980,12 +1978,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
var tb strings.Builder
|
||||
|
||||
currentFormat := req.Format
|
||||
if req.Format != nil && !structuredOutputsStarted {
|
||||
// set format to nil if a parser exists and structured outputs haven't started
|
||||
if req.Format != nil && !structuredOutputsStarted && (thinkingState != nil || builtinParser != nil) {
|
||||
currentFormat = nil
|
||||
}
|
||||
|
||||
// sets up new context given parent context
|
||||
requestCtx, requestCancel := context.WithCancel(ctx)
|
||||
// sets up new context given parent context per request
|
||||
requestCtx, requestCancel := context.WithCancel(c.Request.Context())
|
||||
err := r.Completion(requestCtx, llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
|
@ -2081,7 +2080,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
ch <- res
|
||||
})
|
||||
if err != nil {
|
||||
if applyStructuredOutputs && !structuredOutputsStarted && (errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "context canceled")) && ctx.Err() == nil {
|
||||
if applyStructuredOutputs && !structuredOutputsStarted && (errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "context canceled")) && c.Request.Context().Err() == nil {
|
||||
// only ignores error if it's a context cancellation due to setting structured outputs
|
||||
} else {
|
||||
slog.Error("chat completion error", "error", err.Error())
|
||||
|
|
|
@ -594,6 +594,126 @@ func TestGenerateChat(t *testing.T) {
|
|||
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured outputs format propagation", func(t *testing.T) {
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
streamRequest := false
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-chat-format",
|
||||
From: "test",
|
||||
Stream: &streamRequest,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200 creating test-chat-format, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-chat-format-parser",
|
||||
From: "test-chat-format",
|
||||
Parser: "passthrough",
|
||||
Stream: &streamRequest,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200 creating test-chat-format-parser, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("without parser applies format", func(t *testing.T) {
|
||||
var requests []llm.CompletionRequest
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
requests = append(requests, r)
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"ok"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
t.Cleanup(func() { mock.CompletionFn = nil })
|
||||
|
||||
respRecorder := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-chat-format",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Give me json"},
|
||||
},
|
||||
Format: format,
|
||||
Stream: &streamRequest,
|
||||
})
|
||||
|
||||
if respRecorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", respRecorder.Code)
|
||||
}
|
||||
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 completion call, got %d", len(requests))
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(string(requests[0].Format), string(format)); diff != "" {
|
||||
t.Errorf("format mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(respRecorder.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Message.Content != `{"answer":"ok"}` {
|
||||
t.Errorf("expected structured response, got %q", resp.Message.Content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with parser missing thinking skips format", func(t *testing.T) {
|
||||
var requests []llm.CompletionRequest
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
requests = append(requests, r)
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"ok"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
t.Cleanup(func() { mock.CompletionFn = nil })
|
||||
|
||||
respRecorder := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-chat-format-parser",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Give me json"},
|
||||
},
|
||||
Format: format,
|
||||
Stream: &streamRequest,
|
||||
})
|
||||
|
||||
if respRecorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", respRecorder.Code)
|
||||
}
|
||||
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected 1 completion call, got %d", len(requests))
|
||||
}
|
||||
|
||||
if requests[0].Format != nil {
|
||||
t.Errorf("expected format to be nil when parser is present without thinking")
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(respRecorder.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Message.Content != `{"answer":"ok"}` {
|
||||
t.Errorf("expected structured response passthrough, got %q", resp.Message.Content)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
|
@ -968,6 +1088,7 @@ func TestGenerate(t *testing.T) {
|
|||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
|
@ -1208,7 +1329,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
|||
|
||||
wg.Add(2)
|
||||
|
||||
format := json.RawMessage([]byte(`{"type":"object","properties":{"answer":{"type":"string"}}}`))
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
@ -1311,7 +1432,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
|||
|
||||
wg.Add(2)
|
||||
|
||||
format := json.RawMessage([]byte(`{"type":"object","properties":{"answer":{"type":"string"}}}`))
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
|
Loading…
Reference in New Issue