diff --git a/middleware/openai.go b/middleware/openai.go new file mode 100644 index 000000000..826a2111b --- /dev/null +++ b/middleware/openai.go @@ -0,0 +1,424 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/openai" +) + +type BaseWriter struct { + gin.ResponseWriter +} + +type ChatWriter struct { + stream bool + streamOptions *openai.StreamOptions + id string + toolCallSent bool + BaseWriter +} + +type CompleteWriter struct { + stream bool + streamOptions *openai.StreamOptions + id string + BaseWriter +} + +type ListWriter struct { + BaseWriter +} + +type RetrieveWriter struct { + BaseWriter + model string +} + +type EmbedWriter struct { + BaseWriter + model string +} + +func (w *BaseWriter) writeError(data []byte) (int, error) { + var serr api.StatusError + err := json.Unmarshal(data, &serr) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error())) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *ChatWriter) writeResponse(data []byte) (int, error) { + var chatResponse api.ChatResponse + err := json.Unmarshal(data, &chatResponse) + if err != nil { + return 0, err + } + + // chat chunk + if w.stream { + c := openai.ToChunk(w.id, chatResponse, w.toolCallSent) + d, err := json.Marshal(c) + if err != nil { + return 0, err + } + if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 { + w.toolCallSent = true + } + + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + + if chatResponse.Done { + if w.streamOptions != nil && w.streamOptions.IncludeUsage { + u := openai.ToUsage(chatResponse) + c.Usage = &u + c.Choices = []openai.ChunkChoice{} + d, err := json.Marshal(c) + if err != nil { + return 0, err + } + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + } + _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) + if err != nil { + return 0, err + } + } + + return len(data), nil + } + + // chat completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *ChatWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func (w *CompleteWriter) writeResponse(data []byte) (int, error) { + var generateResponse api.GenerateResponse + err := json.Unmarshal(data, &generateResponse) + if err != nil { + return 0, err + } + + // completion chunk + if w.stream { + c := openai.ToCompleteChunk(w.id, generateResponse) + if w.streamOptions != nil && w.streamOptions.IncludeUsage { + c.Usage = &openai.Usage{} + } + d, err := json.Marshal(c) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + + if generateResponse.Done { + if w.streamOptions != nil && w.streamOptions.IncludeUsage { + u 
:= openai.ToUsageGenerate(generateResponse) + c.Usage = &u + c.Choices = []openai.CompleteChunkChoice{} + d, err := json.Marshal(c) + if err != nil { + return 0, err + } + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + } + _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) + if err != nil { + return 0, err + } + } + + return len(data), nil + } + + // completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *CompleteWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func (w *ListWriter) writeResponse(data []byte) (int, error) { + var listResponse api.ListResponse + err := json.Unmarshal(data, &listResponse) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *ListWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func (w *RetrieveWriter) writeResponse(data []byte) (int, error) { + var showResponse api.ShowResponse + err := json.Unmarshal(data, &showResponse) + if err != nil { + return 0, err + } + + // retrieve completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *RetrieveWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func (w *EmbedWriter) writeResponse(data []byte) (int, error) { + var embedResponse api.EmbedResponse + err := json.Unmarshal(data, &embedResponse) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *EmbedWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(data) + } + + return w.writeResponse(data) +} + +func ListMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + w := &ListWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + } + + c.Writer = w + + c.Next() + } +} + +func RetrieveMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &RetrieveWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: c.Param("model"), + } + + c.Writer = w + + c.Next() + } +} + +func CompletionsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req openai.CompletionRequest + err := 
c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + var b bytes.Buffer + genReq, err := openai.FromCompleteRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if err := json.NewEncoder(&b).Encode(genReq); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &CompleteWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), + streamOptions: req.StreamOptions, + } + + c.Writer = w + c.Next() + } +} + +func EmbeddingsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req openai.EmbedRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if req.Input == "" { + req.Input = []string{""} + } + + if req.Input == nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input")) + return + } + + if v, ok := req.Input.([]any); ok && len(v) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input")) + return + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &EmbedWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: req.Model, + } + + c.Writer = w + + c.Next() + } +} + +func ChatMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req openai.ChatCompletionRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if len(req.Messages) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'")) + return + } + + var b bytes.Buffer + + chatReq, err := openai.FromChatRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error())) + return + } + + if err := json.NewEncoder(&b).Encode(chatReq); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &ChatWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), + streamOptions: req.StreamOptions, + } + + c.Writer = w + + c.Next() + } +} diff --git a/middleware/openai_test.go b/middleware/openai_test.go new file mode 100644 index 000000000..a78ee8b91 --- /dev/null +++ b/middleware/openai_test.go @@ -0,0 +1,928 @@ +package middleware + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/openai" +) + +const ( + prefix = `data:image/jpeg;base64,` + 
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` +) + +var ( + False = false + True = true +) + +func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { + return func(c *gin.Context) { + bodyBytes, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + err := json.Unmarshal(bodyBytes, capturedRequest) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request") + } + c.Next() + } +} + +func TestChatMiddleware(t *testing.T) { + type testCase struct { + name string + body string + req api.ChatRequest + err openai.ErrorResponse + } + + var capturedRequest *api.ChatRequest + + testCases := []testCase{ + { + name: "chat handler", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with options", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "stream": true, + "max_tokens": 999, + "seed": 123, + "stop": ["\n", "stop"], + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + "response_format": {"type": "json_object"} + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "num_predict": 999.0, // float because JSON doesn't distinguish between float and int + "seed": 123.0, + "stop": []any{"\n", "stop"}, + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + }, + Format: json.RawMessage(`"json"`), + Stream: &True, + }, + }, + { + name: "chat handler with streaming usage", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "stream": true, + "stream_options": {"include_usage": true}, + "max_tokens": 999, + "seed": 123, + "stop": ["\n", "stop"], + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + "response_format": {"type": "json_object"} + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "num_predict": 999.0, // float because JSON doesn't distinguish between float and int + "seed": 123.0, + "stop": []any{"\n", "stop"}, + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + }, + Format: json.RawMessage(`"json"`), + Stream: &True, + }, + }, + { + name: "chat handler with image content", + body: `{ + "model": "test-model", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello" + }, + { + "type": "image_url", + "image_url": { + "url": "` + prefix + image + `" + } + } + ] + } + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + { + Role: "user", + Images: []api.ImageData{ + func() []byte { + img, _ := base64.StdEncoding.DecodeString(image) + return img + }(), + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": 
"What's the weather like in Paris Today?"}, + {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools and content", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + Content: "Let's see what the weather is like in Paris", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools and empty content", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools and thinking content", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + Thinking: "Let's see what the weather is like in Paris", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, 
+ "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "tool response with call ID", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}, + {"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + { + Role: "tool", + Content: "The weather in Paris is 20 degrees Celsius", + ToolName: "get_current_weather", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "tool response with name", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}, + {"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "location": "Paris, France", + "format": "celsius", + }, + }, + }, + }, + }, + { + Role: "tool", + Content: "The weather in Paris is 20 degrees Celsius", + ToolName: "get_current_weather", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with streaming tools", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris?"} + ], + "stream": true, + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "required": ["location"], + "properties": { + "location": { + "type": "string", + "description": "The city and state" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + } + } + } + }] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris?", + }, + }, + Tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]api.ToolProperty `json:"properties"` + }{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The city and state", + }, + "unit": { + Type: api.PropertyType{"string"}, + Enum: 
[]any{"celsius", "fahrenheit"}, + }, + }, + }, + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &True, + }, + }, + { + name: "chat handler error forwarding", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": 2} + ] + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "invalid message content type: float64", + Type: "invalid_request_error", + }, + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/chat", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + + defer func() { capturedRequest = nil }() + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var errResp openai.ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + return + } + if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" { + t.Fatalf("requests did not match: %+v", diff) + } + if diff := cmp.Diff(tc.err, errResp); diff != "" { + t.Fatalf("errors did not match for %s:\n%s", tc.name, diff) + } + }) + } +} + +func TestCompletionsMiddleware(t *testing.T) { + type testCase struct { + name string + body string + req api.GenerateRequest + err openai.ErrorResponse + } + + var capturedRequest *api.GenerateRequest + + testCases := []testCase{ + { + name: "completions handler", + body: `{ + "model": "test-model", + "prompt": "Hello", + "temperature": 0.8, + "stop": ["\n", "stop"], + "suffix": "suffix" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "Hello", + Options: map[string]any{ + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 0.8, + "top_p": 1.0, + "stop": []any{"\n", "stop"}, + }, + Suffix: "suffix", + Stream: &False, + }, + }, + { + name: "completions handler stream", + body: `{ + "model": "test-model", + "prompt": "Hello", + "stream": true, + "temperature": 0.8, + "stop": ["\n", "stop"], + "suffix": "suffix" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "Hello", + Options: map[string]any{ + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 0.8, + "top_p": 1.0, + "stop": []any{"\n", "stop"}, + }, + Suffix: "suffix", + Stream: &True, + }, + }, + { + name: "completions handler stream with usage", + body: `{ + "model": "test-model", + "prompt": "Hello", + "stream": true, + "stream_options": {"include_usage": true}, + "temperature": 0.8, + "stop": ["\n", "stop"], + "suffix": "suffix" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "Hello", + Options: map[string]any{ + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 0.8, + "top_p": 1.0, + "stop": []any{"\n", "stop"}, + }, + Suffix: "suffix", + Stream: &True, + }, + }, + { + name: "completions handler error forwarding", + body: `{ + "model": "test-model", + "prompt": "Hello", + "temperature": null, + "stop": [1, 2], + "suffix": "suffix" + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "invalid type for 'stop' field: float64", + Type: "invalid_request_error", + }, + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := 
gin.New() + router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/generate", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var errResp openai.ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } + + capturedRequest = nil + }) + } +} + +func TestEmbeddingsMiddleware(t *testing.T) { + type testCase struct { + name string + body string + req api.EmbedRequest + err openai.ErrorResponse + } + + var capturedRequest *api.EmbedRequest + + testCases := []testCase{ + { + name: "embed handler single input", + body: `{ + "input": "Hello", + "model": "test-model" + }`, + req: api.EmbedRequest{ + Input: "Hello", + Model: "test-model", + }, + }, + { + name: "embed handler batch input", + body: `{ + "input": ["Hello", "World"], + "model": "test-model" + }`, + req: api.EmbedRequest{ + Input: []any{"Hello", "World"}, + Model: "test-model", + }, + }, + { + name: "embed handler error forwarding", + body: `{ + "model": "test-model" + }`, + err: openai.ErrorResponse{ + Error: openai.Error{ + Message: "invalid input", + Type: "invalid_request_error", + }, + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/embed", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var errResp openai.ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } + + capturedRequest = nil + }) + } +} + +func TestListMiddleware(t *testing.T) { + type testCase struct { + name string + endpoint func(c *gin.Context) + resp string + } + + testCases := []testCase{ + { + name: "list handler", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ListResponse{ + Models: []api.ListModelResponse{ + { + Name: "test-model", + ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), + }, + }, + }) + }, + resp: `{ + "object": "list", + "data": [ + { + "id": "test-model", + "object": "model", + "created": 1686935002, + "owned_by": "library" + } + ] + }`, + }, + { + name: "list handler empty output", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ListResponse{}) + }, + resp: `{ + "object": "list", + "data": null + }`, + }, + } + + gin.SetMode(gin.TestMode) + + for _, tc := range testCases { + router := gin.New() + router.Use(ListMiddleware()) + router.Handle(http.MethodGet, "/api/tags", tc.endpoint) + req, _ := 
http.NewRequest(http.MethodGet, "/api/tags", nil) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var expected, actual map[string]any + err := json.Unmarshal([]byte(tc.resp), &expected) + if err != nil { + t.Fatalf("failed to unmarshal expected response: %v", err) + } + + err = json.Unmarshal(resp.Body.Bytes(), &actual) + if err != nil { + t.Fatalf("failed to unmarshal actual response: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + } + } +} + +func TestRetrieveMiddleware(t *testing.T) { + type testCase struct { + name string + endpoint func(c *gin.Context) + resp string + } + + testCases := []testCase{ + { + name: "retrieve handler", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ShowResponse{ + ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), + }) + }, + resp: `{ + "id":"test-model", + "object":"model", + "created":1686935002, + "owned_by":"library"} + `, + }, + { + name: "retrieve handler error forwarding", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"}) + }, + resp: `{ + "error": { + "code": null, + "message": "model not found", + "param": null, + "type": "api_error" + } + }`, + }, + } + + gin.SetMode(gin.TestMode) + + for _, tc := range testCases { + router := gin.New() + router.Use(RetrieveMiddleware()) + router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint) + req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var expected, actual map[string]any + err := json.Unmarshal([]byte(tc.resp), &expected) + if err != nil { + t.Fatalf("failed to unmarshal expected response: %v", err) + } + + err = json.Unmarshal(resp.Body.Bytes(), &actual) + if err != nil { + t.Fatalf("failed to unmarshal actual response: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + } + } +} diff --git a/openai/openai.go b/openai/openai.go index 7ef5ac6de..55e55e97c 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -1,21 +1,17 @@ -// openai package provides middleware for partial compatibility with the OpenAI REST API +// Package openai provides core transformation logic for partial compatibility with the OpenAI REST API package openai import ( - "bytes" "encoding/base64" "encoding/json" "errors" "fmt" - "io" "log/slog" "math/rand" "net/http" "strings" "time" - "github.com/gin-gonic/gin" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/types/model" ) @@ -220,11 +216,12 @@ func NewError(code int, message string) ErrorResponse { return ErrorResponse{Error{Type: etype, Message: message}} } -func toUsage(r api.ChatResponse) Usage { +// ToUsage converts an api.ChatResponse to Usage +func ToUsage(r api.ChatResponse) Usage { return Usage{ - PromptTokens: r.PromptEvalCount, - CompletionTokens: r.EvalCount, - TotalTokens: r.PromptEvalCount + r.EvalCount, + PromptTokens: r.Metrics.PromptEvalCount, + CompletionTokens: r.Metrics.EvalCount, + TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount, } } @@ -256,7 +253,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall { return toolCalls } -func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { +// ToChatCompletion converts an api.ChatResponse to ChatCompletion +func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion { toolCalls := 
toToolCalls(r.Message.ToolCalls) return ChatCompletion{ Id: id, @@ -276,12 +274,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } return nil }(r.DoneReason), - }}, Usage: toUsage(r), + }}, Usage: ToUsage(r), DebugInfo: r.DebugInfo, } } -func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk { +// ToChunk converts an api.ChatResponse to ChatCompletionChunk +func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk { toolCalls := toToolCalls(r.Message.ToolCalls) return ChatCompletionChunk{ Id: id, @@ -305,15 +304,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu } } -func toUsageGenerate(r api.GenerateResponse) Usage { +// ToUsageGenerate converts an api.GenerateResponse to Usage +func ToUsageGenerate(r api.GenerateResponse) Usage { return Usage{ - PromptTokens: r.PromptEvalCount, - CompletionTokens: r.EvalCount, - TotalTokens: r.PromptEvalCount + r.EvalCount, + PromptTokens: r.Metrics.PromptEvalCount, + CompletionTokens: r.Metrics.EvalCount, + TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount, } } -func toCompletion(id string, r api.GenerateResponse) Completion { +// ToCompletion converts an api.GenerateResponse to Completion +func ToCompletion(id string, r api.GenerateResponse) Completion { return Completion{ Id: id, Object: "text_completion", @@ -330,11 +331,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion { return nil }(r.DoneReason), }}, - Usage: toUsageGenerate(r), + Usage: ToUsageGenerate(r), } } -func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { +// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk +func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { return CompletionChunk{ Id: id, Object: "text_completion", @@ -354,7 +356,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { } } -func toListCompletion(r api.ListResponse) ListCompletion { +// ToListCompletion converts an api.ListResponse to ListCompletion +func ToListCompletion(r api.ListResponse) ListCompletion { var data []Model for _, m := range r.Models { data = append(data, Model{ @@ -371,7 +374,8 @@ func toListCompletion(r api.ListResponse) ListCompletion { } } -func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { +// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList +func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { if r.Embeddings != nil { var data []Embedding for i, e := range r.Embeddings { @@ -396,7 +400,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { return EmbeddingList{} } -func toModel(r api.ShowResponse, m string) Model { +// ToModel converts an api.ShowResponse to Model +func ToModel(r api.ShowResponse, m string) Model { return Model{ Id: m, Object: "model", @@ -405,7 +410,8 @@ func toModel(r api.ShowResponse, m string) Model { } } -func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { +// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest +func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { var messages []api.Message for _, msg := range r.Messages { toolName := "" @@ -609,7 +615,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) { return apiToolCalls, nil } -func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { +// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest +func 
FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { options := make(map[string]any) switch stop := r.Stop.(type) { @@ -660,413 +667,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { DebugRenderOnly: r.DebugRenderOnly, }, nil } - -type BaseWriter struct { - gin.ResponseWriter -} - -type ChatWriter struct { - stream bool - streamOptions *StreamOptions - id string - toolCallSent bool - BaseWriter -} - -type CompleteWriter struct { - stream bool - streamOptions *StreamOptions - id string - BaseWriter -} - -type ListWriter struct { - BaseWriter -} - -type RetrieveWriter struct { - BaseWriter - model string -} - -type EmbedWriter struct { - BaseWriter - model string -} - -func (w *BaseWriter) writeError(data []byte) (int, error) { - var serr api.StatusError - err := json.Unmarshal(data, &serr) - if err != nil { - return 0, err - } - - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error())) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *ChatWriter) writeResponse(data []byte) (int, error) { - var chatResponse api.ChatResponse - err := json.Unmarshal(data, &chatResponse) - if err != nil { - return 0, err - } - - // chat chunk - if w.stream { - c := toChunk(w.id, chatResponse, w.toolCallSent) - d, err := json.Marshal(c) - if err != nil { - return 0, err - } - if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 { - w.toolCallSent = true - } - - w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err - } - - if chatResponse.Done { - if w.streamOptions != nil && w.streamOptions.IncludeUsage { - u := toUsage(chatResponse) - c.Usage = &u - c.Choices = []ChunkChoice{} - d, err := json.Marshal(c) - if err != nil { - return 0, err - } - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err - } - } - _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) - if err != nil { - return 0, err - } - } - - return len(data), nil - } - - // chat completion - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *ChatWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func (w *CompleteWriter) writeResponse(data []byte) (int, error) { - var generateResponse api.GenerateResponse - err := json.Unmarshal(data, &generateResponse) - if err != nil { - return 0, err - } - - // completion chunk - if w.stream { - c := toCompleteChunk(w.id, generateResponse) - if w.streamOptions != nil && w.streamOptions.IncludeUsage { - c.Usage = &Usage{} - } - d, err := json.Marshal(c) - if err != nil { - return 0, err - } - - w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err - } - - if generateResponse.Done { - if w.streamOptions != nil && w.streamOptions.IncludeUsage { - u := toUsageGenerate(generateResponse) - c.Usage = &u - c.Choices = []CompleteChunkChoice{} - d, err := json.Marshal(c) - if err != nil { - return 0, err - 
} - _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) - if err != nil { - return 0, err - } - } - _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) - if err != nil { - return 0, err - } - } - - return len(data), nil - } - - // completion - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *CompleteWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func (w *ListWriter) writeResponse(data []byte) (int, error) { - var listResponse api.ListResponse - err := json.Unmarshal(data, &listResponse) - if err != nil { - return 0, err - } - - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *ListWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func (w *RetrieveWriter) writeResponse(data []byte) (int, error) { - var showResponse api.ShowResponse - err := json.Unmarshal(data, &showResponse) - if err != nil { - return 0, err - } - - // retrieve completion - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *RetrieveWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func (w *EmbedWriter) writeResponse(data []byte) (int, error) { - var embedResponse api.EmbedResponse - err := json.Unmarshal(data, &embedResponse) - if err != nil { - return 0, err - } - - w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse)) - if err != nil { - return 0, err - } - - return len(data), nil -} - -func (w *EmbedWriter) Write(data []byte) (int, error) { - code := w.ResponseWriter.Status() - if code != http.StatusOK { - return w.writeError(data) - } - - return w.writeResponse(data) -} - -func ListMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - w := &ListWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - } - - c.Writer = w - - c.Next() - } -} - -func RetrieveMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) - return - } - - c.Request.Body = io.NopCloser(&b) - - // response writer - w := &RetrieveWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - model: c.Param("model"), - } - - c.Writer = w - - c.Next() - } -} - -func CompletionsMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var req CompletionRequest - err := c.ShouldBindJSON(&req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - var b bytes.Buffer - genReq, err := 
fromCompleteRequest(req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - if err := json.NewEncoder(&b).Encode(genReq); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) - return - } - - c.Request.Body = io.NopCloser(&b) - - w := &CompleteWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - stream: req.Stream, - id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), - streamOptions: req.StreamOptions, - } - - c.Writer = w - c.Next() - } -} - -func EmbeddingsMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var req EmbedRequest - err := c.ShouldBindJSON(&req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - if req.Input == "" { - req.Input = []string{""} - } - - if req.Input == nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) - return - } - - if v, ok := req.Input.([]any); ok && len(v) == 0 { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) - return - } - - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) - return - } - - c.Request.Body = io.NopCloser(&b) - - w := &EmbedWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - model: req.Model, - } - - c.Writer = w - - c.Next() - } -} - -func ChatMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var req ChatCompletionRequest - err := c.ShouldBindJSON(&req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - if len(req.Messages) == 0 { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'")) - return - } - - var b bytes.Buffer - - chatReq, err := fromChatRequest(req) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) - return - } - - if err := json.NewEncoder(&b).Encode(chatReq); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) - return - } - - c.Request.Body = io.NopCloser(&b) - - w := &ChatWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - stream: req.Stream, - id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), - streamOptions: req.StreamOptions, - } - - c.Writer = w - - c.Next() - } -} diff --git a/openai/openai_test.go b/openai/openai_test.go index 0d7f016ba..0f1a877f4 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -1,19 +1,8 @@ package openai import ( - "bytes" "encoding/base64" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "reflect" - "strings" "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" ) @@ -23,905 +12,139 @@ const ( image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` ) -var ( - False = false - True = true -) +func TestFromChatRequest_Basic(t *testing.T) { + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + } -func captureRequestMiddleware(capturedRequest any) 
gin.HandlerFunc { - return func(c *gin.Context) { - bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - err := json.Unmarshal(bodyBytes, capturedRequest) - if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request") - } - c.Next() + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", result.Model) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" { + t.Errorf("unexpected message: %+v", result.Messages[0]) } } -func TestChatMiddleware(t *testing.T) { - type testCase struct { - name string - body string - req api.ChatRequest - err ErrorResponse - } +func TestFromChatRequest_WithImage(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(image) - var capturedRequest *api.ChatRequest - - testCases := []testCase{ - { - name: "chat handler", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "Hello"} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "Hello", + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + { + Role: "user", + Content: []any{ + map[string]any{"type": "text", "text": "Hello"}, + map[string]any{ + "type": "image_url", + "image_url": map[string]any{"url": prefix + image}, }, }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with options", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "Hello"} - ], - "stream": true, - "max_tokens": 999, - "seed": 123, - "stop": ["\n", "stop"], - "temperature": 3.0, - "frequency_penalty": 4.0, - "presence_penalty": 5.0, - "top_p": 6.0, - "response_format": {"type": "json_object"} - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "Hello", - }, - }, - Options: map[string]any{ - "num_predict": 999.0, // float because JSON doesn't distinguish between float and int - "seed": 123.0, - "stop": []any{"\n", "stop"}, - "temperature": 3.0, - "frequency_penalty": 4.0, - "presence_penalty": 5.0, - "top_p": 6.0, - }, - Format: json.RawMessage(`"json"`), - Stream: &True, - }, - }, - { - name: "chat handler with streaming usage", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "Hello"} - ], - "stream": true, - "stream_options": {"include_usage": true}, - "max_tokens": 999, - "seed": 123, - "stop": ["\n", "stop"], - "temperature": 3.0, - "frequency_penalty": 4.0, - "presence_penalty": 5.0, - "top_p": 6.0, - "response_format": {"type": "json_object"} - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "Hello", - }, - }, - Options: map[string]any{ - "num_predict": 999.0, // float because JSON doesn't distinguish between float and int - "seed": 123.0, - "stop": []any{"\n", "stop"}, - "temperature": 3.0, - "frequency_penalty": 4.0, - "presence_penalty": 5.0, - "top_p": 6.0, - }, - Format: json.RawMessage(`"json"`), - Stream: &True, - }, - }, - { - name: "chat handler with image content", - body: `{ - "model": "test-model", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Hello" 
- }, - { - "type": "image_url", - "image_url": { - "url": "` + prefix + image + `" - } - } - ] - } - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "Hello", - }, - { - Role: "user", - Images: []api.ImageData{ - func() []byte { - img, _ := base64.StdEncoding.DecodeString(image) - return img - }(), - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with tools", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with tools and content", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - Content: "Let's see what the weather is like in Paris", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with tools and empty content", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with tools and thinking content", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": 
"get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - Thinking: "Let's see what the weather is like in Paris", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "tool response with call ID", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}, - {"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - { - Role: "tool", - Content: "The weather in Paris is 20 degrees Celsius", - ToolName: "get_current_weather", - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "tool response with name", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris Today?"}, - {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}, - {"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"} - ] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris Today?", - }, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: map[string]any{ - "location": "Paris, France", - "format": "celsius", - }, - }, - }, - }, - }, - { - Role: "tool", - Content: "The weather in Paris is 20 degrees Celsius", - ToolName: "get_current_weather", - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &False, - }, - }, - { - name: "chat handler with streaming tools", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": "What's the weather like in Paris?"} - ], - "stream": true, - "tools": [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "required": ["location"], - "properties": { - "location": { - "type": "string", - "description": "The city and state" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - } - } - } - }] - }`, - req: api.ChatRequest{ - Model: "test-model", - Messages: []api.Message{ - { - Role: "user", - Content: "What's the weather like in Paris?", - }, - }, - Tools: 
[]api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Description: "Get the current weather", - Parameters: struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required"` - Properties map[string]api.ToolProperty `json:"properties"` - }{ - Type: "object", - Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ - "location": { - Type: api.PropertyType{"string"}, - Description: "The city and state", - }, - "unit": { - Type: api.PropertyType{"string"}, - Enum: []any{"celsius", "fahrenheit"}, - }, - }, - }, - }, - }, - }, - Options: map[string]any{ - "temperature": 1.0, - "top_p": 1.0, - }, - Stream: &True, - }, - }, - { - name: "chat handler error forwarding", - body: `{ - "model": "test-model", - "messages": [ - {"role": "user", "content": 2} - ] - }`, - err: ErrorResponse{ - Error: Error{ - Message: "invalid message content type: float64", - Type: "invalid_request_error", - }, }, }, } - endpoint := func(c *gin.Context) { - c.Status(http.StatusOK) + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - gin.SetMode(gin.TestMode) - router := gin.New() - router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) - router.Handle(http.MethodPost, "/api/chat", endpoint) + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) - req.Header.Set("Content-Type", "application/json") + if result.Messages[0].Content != "Hello" { + t.Errorf("expected first message content 'Hello', got %q", result.Messages[0].Content) + } - defer func() { capturedRequest = nil }() + if len(result.Messages[1].Images) != 1 { + t.Fatalf("expected 1 image, got %d", len(result.Messages[1].Images)) + } - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - var errResp ErrorResponse - if resp.Code != http.StatusOK { - if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { - t.Fatal(err) - } - return - } - if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" { - t.Fatalf("requests did not match: %+v", diff) - } - if diff := cmp.Diff(tc.err, errResp); diff != "" { - t.Fatalf("errors did not match for %s:\n%s", tc.name, diff) - } - }) + if string(result.Messages[1].Images[0]) != string(imgData) { + t.Error("image data mismatch") } } -func TestCompletionsMiddleware(t *testing.T) { - type testCase struct { - name string - body string - req api.GenerateRequest - err ErrorResponse +func TestFromCompleteRequest_Basic(t *testing.T) { + temp := float32(0.8) + req := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Temperature: &temp, } - var capturedRequest *api.GenerateRequest - - testCases := []testCase{ - { - name: "completions handler", - body: `{ - "model": "test-model", - "prompt": "Hello", - "temperature": 0.8, - "stop": ["\n", "stop"], - "suffix": "suffix" - }`, - req: api.GenerateRequest{ - Model: "test-model", - Prompt: "Hello", - Options: map[string]any{ - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "temperature": 0.8, - "top_p": 1.0, - "stop": []any{"\n", "stop"}, - }, - Suffix: "suffix", - Stream: &False, - }, - }, - { - name: "completions handler stream", - body: `{ - "model": "test-model", - "prompt": "Hello", - "stream": true, - "temperature": 0.8, - "stop": ["\n", "stop"], - 
"suffix": "suffix" - }`, - req: api.GenerateRequest{ - Model: "test-model", - Prompt: "Hello", - Options: map[string]any{ - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "temperature": 0.8, - "top_p": 1.0, - "stop": []any{"\n", "stop"}, - }, - Suffix: "suffix", - Stream: &True, - }, - }, - { - name: "completions handler stream with usage", - body: `{ - "model": "test-model", - "prompt": "Hello", - "stream": true, - "stream_options": {"include_usage": true}, - "temperature": 0.8, - "stop": ["\n", "stop"], - "suffix": "suffix" - }`, - req: api.GenerateRequest{ - Model: "test-model", - Prompt: "Hello", - Options: map[string]any{ - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "temperature": 0.8, - "top_p": 1.0, - "stop": []any{"\n", "stop"}, - }, - Suffix: "suffix", - Stream: &True, - }, - }, - { - name: "completions handler error forwarding", - body: `{ - "model": "test-model", - "prompt": "Hello", - "temperature": null, - "stop": [1, 2], - "suffix": "suffix" - }`, - err: ErrorResponse{ - Error: Error{ - Message: "invalid type for 'stop' field: float64", - Type: "invalid_request_error", - }, - }, - }, + result, err := FromCompleteRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - endpoint := func(c *gin.Context) { - c.Status(http.StatusOK) + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", result.Model) } - gin.SetMode(gin.TestMode) - router := gin.New() - router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) - router.Handle(http.MethodPost, "/api/generate", endpoint) + if result.Prompt != "Hello" { + t.Errorf("expected prompt 'Hello', got %q", result.Prompt) + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) - req.Header.Set("Content-Type", "application/json") - - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - - var errResp ErrorResponse - if resp.Code != http.StatusOK { - if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { - t.Fatal(err) - } - } - - if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") - } - - if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") - } - - capturedRequest = nil - }) + if tempVal, ok := result.Options["temperature"].(float32); !ok || tempVal != 0.8 { + t.Errorf("expected temperature 0.8, got %v", result.Options["temperature"]) } } -func TestEmbeddingsMiddleware(t *testing.T) { - type testCase struct { - name string - body string - req api.EmbedRequest - err ErrorResponse - } - - var capturedRequest *api.EmbedRequest - - testCases := []testCase{ - { - name: "embed handler single input", - body: `{ - "input": "Hello", - "model": "test-model" - }`, - req: api.EmbedRequest{ - Input: "Hello", - Model: "test-model", - }, - }, - { - name: "embed handler batch input", - body: `{ - "input": ["Hello", "World"], - "model": "test-model" - }`, - req: api.EmbedRequest{ - Input: []any{"Hello", "World"}, - Model: "test-model", - }, - }, - { - name: "embed handler error forwarding", - body: `{ - "model": "test-model" - }`, - err: ErrorResponse{ - Error: Error{ - Message: "invalid input", - Type: "invalid_request_error", - }, - }, +func TestToUsage(t *testing.T) { + resp := api.ChatResponse{ + Metrics: api.Metrics{ + PromptEvalCount: 10, + EvalCount: 20, }, } - endpoint := func(c *gin.Context) { - c.Status(http.StatusOK) + usage := ToUsage(resp) + + if 
+func TestToUsage(t *testing.T) {
+	resp := api.ChatResponse{
+		Metrics: api.Metrics{
+			PromptEvalCount: 10,
+			EvalCount:       20,
 		},
 	}
-	endpoint := func(c *gin.Context) {
-		c.Status(http.StatusOK)
+	usage := ToUsage(resp)
+
+	if usage.PromptTokens != 10 {
+		t.Errorf("expected PromptTokens 10, got %d", usage.PromptTokens)
 	}
-	gin.SetMode(gin.TestMode)
-	router := gin.New()
-	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
-	router.Handle(http.MethodPost, "/api/embed", endpoint)
+	if usage.CompletionTokens != 20 {
+		t.Errorf("expected CompletionTokens 20, got %d", usage.CompletionTokens)
+	}
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
-			req.Header.Set("Content-Type", "application/json")
-
-			resp := httptest.NewRecorder()
-			router.ServeHTTP(resp, req)
-
-			var errResp ErrorResponse
-			if resp.Code != http.StatusOK {
-				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
-					t.Fatal(err)
-				}
-			}
-
-			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-				t.Fatal("requests did not match")
-			}
-
-			if !reflect.DeepEqual(tc.err, errResp) {
-				t.Fatal("errors did not match")
-			}
-
-			capturedRequest = nil
-		})
+	if usage.TotalTokens != 30 {
+		t.Errorf("expected TotalTokens 30, got %d", usage.TotalTokens)
 	}
 }
 
-func TestListMiddleware(t *testing.T) {
-	type testCase struct {
-		name     string
-		endpoint func(c *gin.Context)
-		resp     string
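+// TestNewError checks that NewError maps HTTP status codes to the matching
+// OpenAI error types and preserves the error message.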
+func TestNewError(t *testing.T) {
+	tests := []struct {
+		code int
+		want string
+	}{
+		{400, "invalid_request_error"},
+		{404, "not_found_error"},
+		{500, "api_error"},
 	}
-
-	testCases := []testCase{
-		{
-			name: "list handler",
-			endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusOK, api.ListResponse{
-					Models: []api.ListModelResponse{
-						{
-							Name:       "test-model",
-							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
-						},
-					},
-				})
-			},
-			resp: `{
-				"object": "list",
-				"data": [
-					{
-						"id": "test-model",
-						"object": "model",
-						"created": 1686935002,
-						"owned_by": "library"
-					}
-				]
-			}`,
-		},
-		{
-			name: "list handler empty output",
-			endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusOK, api.ListResponse{})
-			},
-			resp: `{
-				"object": "list",
-				"data": null
-			}`,
-		},
-	}
-
-	gin.SetMode(gin.TestMode)
-
-	for _, tc := range testCases {
-		router := gin.New()
-		router.Use(ListMiddleware())
-		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
-		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
-
-		resp := httptest.NewRecorder()
-		router.ServeHTTP(resp, req)
-
-		var expected, actual map[string]any
-		err := json.Unmarshal([]byte(tc.resp), &expected)
-		if err != nil {
-			t.Fatalf("failed to unmarshal expected response: %v", err)
+	for _, tt := range tests {
+		result := NewError(tt.code, "test message")
+		if result.Error.Type != tt.want {
+			t.Errorf("NewError(%d) type = %q, want %q", tt.code, result.Error.Type, tt.want)
 		}
-
-		err = json.Unmarshal(resp.Body.Bytes(), &actual)
-		if err != nil {
-			t.Fatalf("failed to unmarshal actual response: %v", err)
-		}
-
-		if !reflect.DeepEqual(expected, actual) {
-			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
-		}
-	}
-}
-
-func TestRetrieveMiddleware(t *testing.T) {
-	type testCase struct {
-		name     string
-		endpoint func(c *gin.Context)
-		resp     string
-	}
-
-	testCases := []testCase{
-		{
-			name: "retrieve handler",
-			endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusOK, api.ShowResponse{
-					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
-				})
-			},
-			resp: `{
-				"id":"test-model",
-				"object":"model",
-				"created":1686935002,
-				"owned_by":"library"}
-			`,
-		},
-		{
-			name: "retrieve handler error forwarding",
-			endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
-			},
-			resp: `{
-				"error": {
-					"code": null,
-					"message": "model not found",
-					"param": null,
-					"type": "api_error"
-				}
-			}`,
-		},
-	}
-
-	gin.SetMode(gin.TestMode)
-
-	for _, tc := range testCases {
-		router := gin.New()
-		router.Use(RetrieveMiddleware())
-		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
-		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
-
-		resp := httptest.NewRecorder()
-		router.ServeHTTP(resp, req)
-
-		var expected, actual map[string]any
-		err := json.Unmarshal([]byte(tc.resp), &expected)
-		if err != nil {
-			t.Fatalf("failed to unmarshal expected response: %v", err)
-		}
-
-		err = json.Unmarshal(resp.Body.Bytes(), &actual)
-		if err != nil {
-			t.Fatalf("failed to unmarshal actual response: %v", err)
-		}
-
-		if !reflect.DeepEqual(expected, actual) {
-			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+		if result.Error.Message != "test message" {
+			t.Errorf("NewError(%d) message = %q, want %q", tt.code, result.Error.Message, "test message")
 		}
 	}
 }
diff --git a/server/routes.go b/server/routes.go
index 7e0ba1c60..c7052f1b7 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -37,8 +37,8 @@ import (
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/middleware"
 	"github.com/ollama/ollama/model/parsers"
-	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
 	"github.com/ollama/ollama/template"
@@ -1449,11 +1449,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
 
 	// Inference (OpenAI compatibility)
-	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
-	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
-	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
-	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
-	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
+	r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
+	r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
+	r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
+	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
+	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
 
 	if rc != nil {
 		// wrap old with new