remove extra context, fix non thinking constraining

This commit is contained in:
ParthSareen 2025-10-06 13:55:43 -07:00
parent 2047dd2b38
commit b3d8274741
2 changed files with 128 additions and 8 deletions

View File

@ -1967,12 +1967,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
toolParser = tools.NewParser(m.Template.Template, req.Tools)
}
ctx, cancel := context.WithCancel(c.Request.Context())
structuredOutputsStarted := false
ch := make(chan any)
go func() {
defer close(ch)
defer cancel()
applyStructuredOutputs := false
@ -1980,12 +1978,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
var tb strings.Builder
currentFormat := req.Format
if req.Format != nil && !structuredOutputsStarted {
// set format to nil if a parser exists and structured outputs haven't started
if req.Format != nil && !structuredOutputsStarted && (thinkingState != nil || builtinParser != nil) {
currentFormat = nil
}
// sets up new context given parent context
requestCtx, requestCancel := context.WithCancel(ctx)
// sets up new context given parent context per request
requestCtx, requestCancel := context.WithCancel(c.Request.Context())
err := r.Completion(requestCtx, llm.CompletionRequest{
Prompt: prompt,
Images: images,
@ -2081,7 +2080,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch <- res
})
if err != nil {
if applyStructuredOutputs && !structuredOutputsStarted && (errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "context canceled")) && ctx.Err() == nil {
if applyStructuredOutputs && !structuredOutputsStarted && (errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "context canceled")) && c.Request.Context().Err() == nil {
// only ignores error if it's a context cancellation due to setting structured outputs
} else {
slog.Error("chat completion error", "error", err.Error())

View File

@ -594,6 +594,126 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
}
})
t.Run("structured outputs format propagation", func(t *testing.T) {
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
streamRequest := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-chat-format",
From: "test",
Stream: &streamRequest,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200 creating test-chat-format, got %d", w.Code)
}
w = createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-chat-format-parser",
From: "test-chat-format",
Parser: "passthrough",
Stream: &streamRequest,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200 creating test-chat-format-parser, got %d", w.Code)
}
t.Run("without parser applies format", func(t *testing.T) {
var requests []llm.CompletionRequest
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
requests = append(requests, r)
fn(llm.CompletionResponse{
Content: `{"answer":"ok"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
}
t.Cleanup(func() { mock.CompletionFn = nil })
respRecorder := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-chat-format",
Messages: []api.Message{
{Role: "user", Content: "Give me json"},
},
Format: format,
Stream: &streamRequest,
})
if respRecorder.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", respRecorder.Code)
}
if len(requests) != 1 {
t.Fatalf("expected 1 completion call, got %d", len(requests))
}
if diff := cmp.Diff(string(requests[0].Format), string(format)); diff != "" {
t.Errorf("format mismatch (-got +want):\n%s", diff)
}
var resp api.ChatResponse
if err := json.NewDecoder(respRecorder.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.Content != `{"answer":"ok"}` {
t.Errorf("expected structured response, got %q", resp.Message.Content)
}
})
t.Run("with parser missing thinking skips format", func(t *testing.T) {
var requests []llm.CompletionRequest
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
requests = append(requests, r)
fn(llm.CompletionResponse{
Content: `{"answer":"ok"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
}
t.Cleanup(func() { mock.CompletionFn = nil })
respRecorder := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-chat-format-parser",
Messages: []api.Message{
{Role: "user", Content: "Give me json"},
},
Format: format,
Stream: &streamRequest,
})
if respRecorder.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", respRecorder.Code)
}
if len(requests) != 1 {
t.Fatalf("expected 1 completion call, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected format to be nil when parser is present without thinking")
}
var resp api.ChatResponse
if err := json.NewDecoder(respRecorder.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.Content != `{"answer":"ok"}` {
t.Errorf("expected structured response passthrough, got %q", resp.Message.Content)
}
})
})
}
func TestGenerate(t *testing.T) {
@ -968,6 +1088,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
func TestChatWithPromptEndingInThinkTag(t *testing.T) {
@ -1208,7 +1329,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
wg.Add(2)
format := json.RawMessage([]byte(`{"type":"object","properties":{"answer":{"type":"string"}}}`))
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
@ -1311,7 +1432,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
wg.Add(2)
format := json.RawMessage([]byte(`{"type":"object","properties":{"answer":{"type":"string"}}}`))
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()