mirror of https://github.com/ollama/ollama.git
				
				
				
			
		
			
				
	
	
		
			497 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			497 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
| package openai
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"io"
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/gin-gonic/gin"
 | |
| 	"github.com/ollama/ollama/api"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| )
 | |
| 
 | |
// Fixtures for image content in chat messages.
//
// image is a base64-encoded 1x1 PNG: its payload starts with the PNG
// signature ("iVBORw0KGgo" decodes to "\x89PNG\r\n\x1a\n"). imageURL wraps
// it in a data URL using prefix.
//
// NOTE(review): prefix labels the data as image/jpeg even though the payload
// is a PNG; the tests only depend on the prefix being accepted and stripped.
const (
	prefix   = `data:image/jpeg;base64,`
	image    = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
	imageURL = prefix + image
)
 | |
| 
 | |
// prepareRequest JSON-encodes body and installs it as req's body, setting a
// matching Content-Type and Content-Length.
//
// It panics if body cannot be marshaled. That can only be a bug in the test
// itself, and silently sending an empty body instead would surface as a
// confusing failure somewhere downstream.
func prepareRequest(req *http.Request, body any) {
	bodyBytes, err := json.Marshal(body)
	if err != nil {
		panic("prepareRequest: marshaling request body: " + err.Error())
	}
	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
	// Keep the declared length consistent with the body we just installed;
	// a request built with a nil body starts out with ContentLength == 0.
	req.ContentLength = int64(len(bodyBytes))
	req.Header.Set("Content-Type", "application/json")
}
 | |
| 
 | |
| func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		bodyBytes, _ := io.ReadAll(c.Request.Body)
 | |
| 		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
 | |
| 		err := json.Unmarshal(bodyBytes, capturedRequest)
 | |
| 		if err != nil {
 | |
| 			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
 | |
| 		}
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestChatMiddleware(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		Name     string
 | |
| 		Setup    func(t *testing.T, req *http.Request)
 | |
| 		Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
 | |
| 	}
 | |
| 
 | |
| 	var capturedRequest *api.ChatRequest
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			Name: "chat handler",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				body := ChatCompletionRequest{
 | |
| 					Model:    "test-model",
 | |
| 					Messages: []Message{{Role: "user", Content: "Hello"}},
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
 | |
| 				if resp.Code != http.StatusOK {
 | |
| 					t.Fatalf("expected 200, got %d", resp.Code)
 | |
| 				}
 | |
| 
 | |
| 				if req.Messages[0].Role != "user" {
 | |
| 					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
 | |
| 				}
 | |
| 
 | |
| 				if req.Messages[0].Content != "Hello" {
 | |
| 					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name: "chat handler with image content",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				body := ChatCompletionRequest{
 | |
| 					Model: "test-model",
 | |
| 					Messages: []Message{
 | |
| 						{
 | |
| 							Role: "user", Content: []map[string]any{
 | |
| 								{"type": "text", "text": "Hello"},
 | |
| 								{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
 | |
| 							},
 | |
| 						},
 | |
| 					},
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
 | |
| 				if resp.Code != http.StatusOK {
 | |
| 					t.Fatalf("expected 200, got %d", resp.Code)
 | |
| 				}
 | |
| 
 | |
| 				if req.Messages[0].Role != "user" {
 | |
| 					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
 | |
| 				}
 | |
| 
 | |
| 				if req.Messages[0].Content != "Hello" {
 | |
| 					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
 | |
| 				}
 | |
| 
 | |
| 				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
 | |
| 
 | |
| 				if req.Messages[1].Role != "user" {
 | |
| 					t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
 | |
| 				}
 | |
| 
 | |
| 				if !bytes.Equal(req.Messages[1].Images[0], img) {
 | |
| 					t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name: "chat handler with tools",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				body := ChatCompletionRequest{
 | |
| 					Model: "test-model",
 | |
| 					Messages: []Message{
 | |
| 						{Role: "user", Content: "What's the weather like in Paris Today?"},
 | |
| 						{Role: "assistant", ToolCalls: []ToolCall{{
 | |
| 							ID:   "id",
 | |
| 							Type: "function",
 | |
| 							Function: struct {
 | |
| 								Name      string `json:"name"`
 | |
| 								Arguments string `json:"arguments"`
 | |
| 							}{
 | |
| 								Name:      "get_current_weather",
 | |
| 								Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
 | |
| 							},
 | |
| 						}}},
 | |
| 					},
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
 | |
| 				if resp.Code != 200 {
 | |
| 					t.Fatalf("expected 200, got %d", resp.Code)
 | |
| 				}
 | |
| 
 | |
| 				if req.Messages[0].Content != "What's the weather like in Paris Today?" {
 | |
| 					t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
 | |
| 				}
 | |
| 
 | |
| 				if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
 | |
| 					t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
 | |
| 				}
 | |
| 
 | |
| 				if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
 | |
| 					t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name: "chat handler error forwarding",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				body := ChatCompletionRequest{
 | |
| 					Model:    "test-model",
 | |
| 					Messages: []Message{{Role: "user", Content: 2}},
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
 | |
| 				if resp.Code != http.StatusBadRequest {
 | |
| 					t.Fatalf("expected 400, got %d", resp.Code)
 | |
| 				}
 | |
| 
 | |
| 				if !strings.Contains(resp.Body.String(), "invalid message content type") {
 | |
| 					t.Fatalf("error was not forwarded")
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	endpoint := func(c *gin.Context) {
 | |
| 		c.Status(http.StatusOK)
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 	router := gin.New()
 | |
| 	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
 | |
| 	router.Handle(http.MethodPost, "/api/chat", endpoint)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.Name, func(t *testing.T) {
 | |
| 			req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
 | |
| 
 | |
| 			tc.Setup(t, req)
 | |
| 
 | |
| 			resp := httptest.NewRecorder()
 | |
| 			router.ServeHTTP(resp, req)
 | |
| 
 | |
| 			tc.Expected(t, capturedRequest, resp)
 | |
| 
 | |
| 			capturedRequest = nil
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestCompletionsMiddleware(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		Name     string
 | |
| 		Setup    func(t *testing.T, req *http.Request)
 | |
| 		Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
 | |
| 	}
 | |
| 
 | |
| 	var capturedRequest *api.GenerateRequest
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			Name: "completions handler",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				temp := float32(0.8)
 | |
| 				body := CompletionRequest{
 | |
| 					Model:       "test-model",
 | |
| 					Prompt:      "Hello",
 | |
| 					Temperature: &temp,
 | |
| 					Stop:        []string{"\n", "stop"},
 | |
| 					Suffix:      "suffix",
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
 | |
| 				if req.Prompt != "Hello" {
 | |
| 					t.Fatalf("expected 'Hello', got %s", req.Prompt)
 | |
| 				}
 | |
| 
 | |
| 				if req.Options["temperature"] != 1.6 {
 | |
| 					t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
 | |
| 				}
 | |
| 
 | |
| 				stopTokens, ok := req.Options["stop"].([]any)
 | |
| 
 | |
| 				if !ok {
 | |
| 					t.Fatalf("expected stop tokens to be a list")
 | |
| 				}
 | |
| 
 | |
| 				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
 | |
| 					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
 | |
| 				}
 | |
| 
 | |
| 				if req.Suffix != "suffix" {
 | |
| 					t.Fatalf("expected 'suffix', got %s", req.Suffix)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name: "completions handler error forwarding",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				body := CompletionRequest{
 | |
| 					Model:       "test-model",
 | |
| 					Prompt:      "Hello",
 | |
| 					Temperature: nil,
 | |
| 					Stop:        []int{1, 2},
 | |
| 					Suffix:      "suffix",
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
 | |
| 				if resp.Code != http.StatusBadRequest {
 | |
| 					t.Fatalf("expected 400, got %d", resp.Code)
 | |
| 				}
 | |
| 
 | |
| 				if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
 | |
| 					t.Fatalf("error was not forwarded")
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	endpoint := func(c *gin.Context) {
 | |
| 		c.Status(http.StatusOK)
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 	router := gin.New()
 | |
| 	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
 | |
| 	router.Handle(http.MethodPost, "/api/generate", endpoint)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.Name, func(t *testing.T) {
 | |
| 			req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
 | |
| 
 | |
| 			tc.Setup(t, req)
 | |
| 
 | |
| 			resp := httptest.NewRecorder()
 | |
| 			router.ServeHTTP(resp, req)
 | |
| 
 | |
| 			tc.Expected(t, capturedRequest, resp)
 | |
| 
 | |
| 			capturedRequest = nil
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestEmbeddingsMiddleware(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		Name     string
 | |
| 		Setup    func(t *testing.T, req *http.Request)
 | |
| 		Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
 | |
| 	}
 | |
| 
 | |
| 	var capturedRequest *api.EmbedRequest
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			Name: "embed handler single input",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				body := EmbedRequest{
 | |
| 					Input: "Hello",
 | |
| 					Model: "test-model",
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
 | |
| 				if req.Input != "Hello" {
 | |
| 					t.Fatalf("expected 'Hello', got %s", req.Input)
 | |
| 				}
 | |
| 
 | |
| 				if req.Model != "test-model" {
 | |
| 					t.Fatalf("expected 'test-model', got %s", req.Model)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name: "embed handler batch input",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				body := EmbedRequest{
 | |
| 					Input: []string{"Hello", "World"},
 | |
| 					Model: "test-model",
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
 | |
| 				input, ok := req.Input.([]any)
 | |
| 
 | |
| 				if !ok {
 | |
| 					t.Fatalf("expected input to be a list")
 | |
| 				}
 | |
| 
 | |
| 				if input[0].(string) != "Hello" {
 | |
| 					t.Fatalf("expected 'Hello', got %s", input[0])
 | |
| 				}
 | |
| 
 | |
| 				if input[1].(string) != "World" {
 | |
| 					t.Fatalf("expected 'World', got %s", input[1])
 | |
| 				}
 | |
| 
 | |
| 				if req.Model != "test-model" {
 | |
| 					t.Fatalf("expected 'test-model', got %s", req.Model)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name: "embed handler error forwarding",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				body := EmbedRequest{
 | |
| 					Model: "test-model",
 | |
| 				}
 | |
| 				prepareRequest(req, body)
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
 | |
| 				if resp.Code != http.StatusBadRequest {
 | |
| 					t.Fatalf("expected 400, got %d", resp.Code)
 | |
| 				}
 | |
| 
 | |
| 				if !strings.Contains(resp.Body.String(), "invalid input") {
 | |
| 					t.Fatalf("error was not forwarded")
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	endpoint := func(c *gin.Context) {
 | |
| 		c.Status(http.StatusOK)
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 	router := gin.New()
 | |
| 	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
 | |
| 	router.Handle(http.MethodPost, "/api/embed", endpoint)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.Name, func(t *testing.T) {
 | |
| 			req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
 | |
| 
 | |
| 			tc.Setup(t, req)
 | |
| 
 | |
| 			resp := httptest.NewRecorder()
 | |
| 			router.ServeHTTP(resp, req)
 | |
| 
 | |
| 			tc.Expected(t, capturedRequest, resp)
 | |
| 
 | |
| 			capturedRequest = nil
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestMiddlewareResponses(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		Name     string
 | |
| 		Method   string
 | |
| 		Path     string
 | |
| 		TestPath string
 | |
| 		Handler  func() gin.HandlerFunc
 | |
| 		Endpoint func(c *gin.Context)
 | |
| 		Setup    func(t *testing.T, req *http.Request)
 | |
| 		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
 | |
| 	}
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			Name:     "list handler",
 | |
| 			Method:   http.MethodGet,
 | |
| 			Path:     "/api/tags",
 | |
| 			TestPath: "/api/tags",
 | |
| 			Handler:  ListMiddleware,
 | |
| 			Endpoint: func(c *gin.Context) {
 | |
| 				c.JSON(http.StatusOK, api.ListResponse{
 | |
| 					Models: []api.ListModelResponse{
 | |
| 						{
 | |
| 							Name: "Test Model",
 | |
| 						},
 | |
| 					},
 | |
| 				})
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
 | |
| 				var listResp ListCompletion
 | |
| 				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
 | |
| 					t.Fatal(err)
 | |
| 				}
 | |
| 
 | |
| 				if listResp.Object != "list" {
 | |
| 					t.Fatalf("expected list, got %s", listResp.Object)
 | |
| 				}
 | |
| 
 | |
| 				if len(listResp.Data) != 1 {
 | |
| 					t.Fatalf("expected 1, got %d", len(listResp.Data))
 | |
| 				}
 | |
| 
 | |
| 				if listResp.Data[0].Id != "Test Model" {
 | |
| 					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:     "retrieve model",
 | |
| 			Method:   http.MethodGet,
 | |
| 			Path:     "/api/show/:model",
 | |
| 			TestPath: "/api/show/test-model",
 | |
| 			Handler:  RetrieveMiddleware,
 | |
| 			Endpoint: func(c *gin.Context) {
 | |
| 				c.JSON(http.StatusOK, api.ShowResponse{
 | |
| 					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
 | |
| 				})
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
 | |
| 				var retrieveResp Model
 | |
| 				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
 | |
| 					t.Fatal(err)
 | |
| 				}
 | |
| 
 | |
| 				if retrieveResp.Object != "model" {
 | |
| 					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
 | |
| 				}
 | |
| 
 | |
| 				if retrieveResp.Id != "test-model" {
 | |
| 					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 	router := gin.New()
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.Name, func(t *testing.T) {
 | |
| 			router = gin.New()
 | |
| 			router.Use(tc.Handler())
 | |
| 			router.Handle(tc.Method, tc.Path, tc.Endpoint)
 | |
| 			req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
 | |
| 
 | |
| 			if tc.Setup != nil {
 | |
| 				tc.Setup(t, req)
 | |
| 			}
 | |
| 
 | |
| 			resp := httptest.NewRecorder()
 | |
| 			router.ServeHTTP(resp, req)
 | |
| 
 | |
| 			assert.Equal(t, http.StatusOK, resp.Code)
 | |
| 
 | |
| 			tc.Expected(t, resp)
 | |
| 		})
 | |
| 	}
 | |
| }
 |