Mirror of https://github.com/ollama/ollama.git
Go · 520 lines · 11 KiB
| package openai
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"io"
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/gin-gonic/gin"
 | |
| 
 | |
| 	"github.com/ollama/ollama/api"
 | |
| )
 | |
| 
 | |
// prefix is the data-URL header placed in front of the base64 payload when a
// test sends an image_url message part.
const prefix = `data:image/jpeg;base64,`

// image is a minimal base64-encoded image used as the image_url payload. Its
// leading characters ("iVBORw0KGgo") decode to the PNG signature, even though
// prefix advertises image/jpeg.
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`

// False is an addressable false value so expected requests can set
// Stream: &False. The zero value of bool is false.
var False bool
 | |
| 
 | |
| func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		bodyBytes, _ := io.ReadAll(c.Request.Body)
 | |
| 		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
 | |
| 		err := json.Unmarshal(bodyBytes, capturedRequest)
 | |
| 		if err != nil {
 | |
| 			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
 | |
| 		}
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestChatMiddleware(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		name string
 | |
| 		body string
 | |
| 		req  api.ChatRequest
 | |
| 		err  ErrorResponse
 | |
| 	}
 | |
| 
 | |
| 	var capturedRequest *api.ChatRequest
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			name: "chat handler",
 | |
| 			body: `{
 | |
| 				"model": "test-model",
 | |
| 				"messages": [
 | |
| 					{"role": "user", "content": "Hello"}
 | |
| 				]
 | |
| 			}`,
 | |
| 			req: api.ChatRequest{
 | |
| 				Model: "test-model",
 | |
| 				Messages: []api.Message{
 | |
| 					{
 | |
| 						Role:    "user",
 | |
| 						Content: "Hello",
 | |
| 					},
 | |
| 				},
 | |
| 				Options: map[string]any{
 | |
| 					"temperature": 1.0,
 | |
| 					"top_p":       1.0,
 | |
| 				},
 | |
| 				Stream: &False,
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "chat handler with image content",
 | |
| 			body: `{
 | |
| 				"model": "test-model",
 | |
| 				"messages": [
 | |
| 					{
 | |
| 						"role": "user",
 | |
| 						"content": [
 | |
| 							{
 | |
| 								"type": "text",
 | |
| 								"text": "Hello"
 | |
| 							},
 | |
| 							{
 | |
| 								"type": "image_url",
 | |
| 								"image_url": {
 | |
| 									"url": "` + prefix + image + `"
 | |
| 								}
 | |
| 							}
 | |
| 						]
 | |
| 					}
 | |
| 				]
 | |
| 			}`,
 | |
| 			req: api.ChatRequest{
 | |
| 				Model: "test-model",
 | |
| 				Messages: []api.Message{
 | |
| 					{
 | |
| 						Role:    "user",
 | |
| 						Content: "Hello",
 | |
| 					},
 | |
| 					{
 | |
| 						Role: "user",
 | |
| 						Images: []api.ImageData{
 | |
| 							func() []byte {
 | |
| 								img, _ := base64.StdEncoding.DecodeString(image)
 | |
| 								return img
 | |
| 							}(),
 | |
| 						},
 | |
| 					},
 | |
| 				},
 | |
| 				Options: map[string]any{
 | |
| 					"temperature": 1.0,
 | |
| 					"top_p":       1.0,
 | |
| 				},
 | |
| 				Stream: &False,
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "chat handler with tools",
 | |
| 			body: `{
 | |
| 				"model": "test-model",
 | |
| 				"messages": [
 | |
| 					{"role": "user", "content": "What's the weather like in Paris Today?"},
 | |
| 					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
 | |
| 				]
 | |
| 			}`,
 | |
| 			req: api.ChatRequest{
 | |
| 				Model: "test-model",
 | |
| 				Messages: []api.Message{
 | |
| 					{
 | |
| 						Role:    "user",
 | |
| 						Content: "What's the weather like in Paris Today?",
 | |
| 					},
 | |
| 					{
 | |
| 						Role: "assistant",
 | |
| 						ToolCalls: []api.ToolCall{
 | |
| 							{
 | |
| 								Function: api.ToolCallFunction{
 | |
| 									Name: "get_current_weather",
 | |
| 									Arguments: map[string]interface{}{
 | |
| 										"location": "Paris, France",
 | |
| 										"format":   "celsius",
 | |
| 									},
 | |
| 								},
 | |
| 							},
 | |
| 						},
 | |
| 					},
 | |
| 				},
 | |
| 				Options: map[string]any{
 | |
| 					"temperature": 1.0,
 | |
| 					"top_p":       1.0,
 | |
| 				},
 | |
| 				Stream: &False,
 | |
| 			},
 | |
| 		},
 | |
| 
 | |
| 		{
 | |
| 			name: "chat handler error forwarding",
 | |
| 			body: `{
 | |
| 				"model": "test-model",
 | |
| 				"messages": [
 | |
| 					{"role": "user", "content": 2}
 | |
| 				]
 | |
| 			}`,
 | |
| 			err: ErrorResponse{
 | |
| 				Error: Error{
 | |
| 					Message: "invalid message content type: float64",
 | |
| 					Type:    "invalid_request_error",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	endpoint := func(c *gin.Context) {
 | |
| 		c.Status(http.StatusOK)
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 	router := gin.New()
 | |
| 	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
 | |
| 	router.Handle(http.MethodPost, "/api/chat", endpoint)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
 | |
| 			req.Header.Set("Content-Type", "application/json")
 | |
| 
 | |
| 			resp := httptest.NewRecorder()
 | |
| 			router.ServeHTTP(resp, req)
 | |
| 
 | |
| 			var errResp ErrorResponse
 | |
| 			if resp.Code != http.StatusOK {
 | |
| 				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 | |
| 					t.Fatal(err)
 | |
| 				}
 | |
| 			}
 | |
| 			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
 | |
| 				t.Fatal("requests did not match")
 | |
| 			}
 | |
| 
 | |
| 			if !reflect.DeepEqual(tc.err, errResp) {
 | |
| 				t.Fatal("errors did not match")
 | |
| 			}
 | |
| 			capturedRequest = nil
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestCompletionsMiddleware(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		name string
 | |
| 		body string
 | |
| 		req  api.GenerateRequest
 | |
| 		err  ErrorResponse
 | |
| 	}
 | |
| 
 | |
| 	var capturedRequest *api.GenerateRequest
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			name: "completions handler",
 | |
| 			body: `{
 | |
| 				"model": "test-model",
 | |
| 				"prompt": "Hello",
 | |
| 				"temperature": 0.8,
 | |
| 				"stop": ["\n", "stop"],
 | |
| 				"suffix": "suffix"
 | |
| 			}`,
 | |
| 			req: api.GenerateRequest{
 | |
| 				Model:  "test-model",
 | |
| 				Prompt: "Hello",
 | |
| 				Options: map[string]any{
 | |
| 					"frequency_penalty": 0.0,
 | |
| 					"presence_penalty":  0.0,
 | |
| 					"temperature":       1.6,
 | |
| 					"top_p":             1.0,
 | |
| 					"stop":              []any{"\n", "stop"},
 | |
| 				},
 | |
| 				Suffix: "suffix",
 | |
| 				Stream: &False,
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "completions handler error forwarding",
 | |
| 			body: `{
 | |
| 				"model": "test-model",
 | |
| 				"prompt": "Hello",
 | |
| 				"temperature": null,
 | |
| 				"stop": [1, 2],
 | |
| 				"suffix": "suffix"
 | |
| 			}`,
 | |
| 			err: ErrorResponse{
 | |
| 				Error: Error{
 | |
| 					Message: "invalid type for 'stop' field: float64",
 | |
| 					Type:    "invalid_request_error",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	endpoint := func(c *gin.Context) {
 | |
| 		c.Status(http.StatusOK)
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 	router := gin.New()
 | |
| 	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
 | |
| 	router.Handle(http.MethodPost, "/api/generate", endpoint)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
 | |
| 			req.Header.Set("Content-Type", "application/json")
 | |
| 
 | |
| 			resp := httptest.NewRecorder()
 | |
| 			router.ServeHTTP(resp, req)
 | |
| 
 | |
| 			var errResp ErrorResponse
 | |
| 			if resp.Code != http.StatusOK {
 | |
| 				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 | |
| 					t.Fatal(err)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
 | |
| 				t.Fatal("requests did not match")
 | |
| 			}
 | |
| 
 | |
| 			if !reflect.DeepEqual(tc.err, errResp) {
 | |
| 				t.Fatal("errors did not match")
 | |
| 			}
 | |
| 
 | |
| 			capturedRequest = nil
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestEmbeddingsMiddleware(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		name string
 | |
| 		body string
 | |
| 		req  api.EmbedRequest
 | |
| 		err  ErrorResponse
 | |
| 	}
 | |
| 
 | |
| 	var capturedRequest *api.EmbedRequest
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			name: "embed handler single input",
 | |
| 			body: `{
 | |
| 				"input": "Hello",
 | |
| 				"model": "test-model"
 | |
| 			}`,
 | |
| 			req: api.EmbedRequest{
 | |
| 				Input: "Hello",
 | |
| 				Model: "test-model",
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "embed handler batch input",
 | |
| 			body: `{
 | |
| 				"input": ["Hello", "World"],
 | |
| 				"model": "test-model"
 | |
| 			}`,
 | |
| 			req: api.EmbedRequest{
 | |
| 				Input: []any{"Hello", "World"},
 | |
| 				Model: "test-model",
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "embed handler error forwarding",
 | |
| 			body: `{
 | |
| 				"model": "test-model"
 | |
| 			}`,
 | |
| 			err: ErrorResponse{
 | |
| 				Error: Error{
 | |
| 					Message: "invalid input",
 | |
| 					Type:    "invalid_request_error",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	endpoint := func(c *gin.Context) {
 | |
| 		c.Status(http.StatusOK)
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 	router := gin.New()
 | |
| 	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
 | |
| 	router.Handle(http.MethodPost, "/api/embed", endpoint)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
 | |
| 			req.Header.Set("Content-Type", "application/json")
 | |
| 
 | |
| 			resp := httptest.NewRecorder()
 | |
| 			router.ServeHTTP(resp, req)
 | |
| 
 | |
| 			var errResp ErrorResponse
 | |
| 			if resp.Code != http.StatusOK {
 | |
| 				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 | |
| 					t.Fatal(err)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
 | |
| 				t.Fatal("requests did not match")
 | |
| 			}
 | |
| 
 | |
| 			if !reflect.DeepEqual(tc.err, errResp) {
 | |
| 				t.Fatal("errors did not match")
 | |
| 			}
 | |
| 
 | |
| 			capturedRequest = nil
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestListMiddleware(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		name     string
 | |
| 		endpoint func(c *gin.Context)
 | |
| 		resp     string
 | |
| 	}
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			name: "list handler",
 | |
| 			endpoint: func(c *gin.Context) {
 | |
| 				c.JSON(http.StatusOK, api.ListResponse{
 | |
| 					Models: []api.ListModelResponse{
 | |
| 						{
 | |
| 							Name:       "test-model",
 | |
| 							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
 | |
| 						},
 | |
| 					},
 | |
| 				})
 | |
| 			},
 | |
| 			resp: `{
 | |
| 				"object": "list",
 | |
| 				"data": [
 | |
| 					{
 | |
| 						"id": "test-model",
 | |
| 						"object": "model",
 | |
| 						"created": 1686935002,
 | |
| 						"owned_by": "library"
 | |
| 					}
 | |
| 				]
 | |
| 			}`,
 | |
| 		},
 | |
| 		{
 | |
| 			name: "list handler empty output",
 | |
| 			endpoint: func(c *gin.Context) {
 | |
| 				c.JSON(http.StatusOK, api.ListResponse{})
 | |
| 			},
 | |
| 			resp: `{
 | |
| 				"object": "list",
 | |
| 				"data": null
 | |
| 			}`,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		router := gin.New()
 | |
| 		router.Use(ListMiddleware())
 | |
| 		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
 | |
| 		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
 | |
| 
 | |
| 		resp := httptest.NewRecorder()
 | |
| 		router.ServeHTTP(resp, req)
 | |
| 
 | |
| 		var expected, actual map[string]any
 | |
| 		err := json.Unmarshal([]byte(tc.resp), &expected)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("failed to unmarshal expected response: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		err = json.Unmarshal(resp.Body.Bytes(), &actual)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("failed to unmarshal actual response: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		if !reflect.DeepEqual(expected, actual) {
 | |
| 			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestRetrieveMiddleware(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		name     string
 | |
| 		endpoint func(c *gin.Context)
 | |
| 		resp     string
 | |
| 	}
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			name: "retrieve handler",
 | |
| 			endpoint: func(c *gin.Context) {
 | |
| 				c.JSON(http.StatusOK, api.ShowResponse{
 | |
| 					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
 | |
| 				})
 | |
| 			},
 | |
| 			resp: `{
 | |
| 				"id":"test-model",
 | |
| 				"object":"model",
 | |
| 				"created":1686935002,
 | |
| 				"owned_by":"library"}
 | |
| 			`,
 | |
| 		},
 | |
| 		{
 | |
| 			name: "retrieve handler error forwarding",
 | |
| 			endpoint: func(c *gin.Context) {
 | |
| 				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
 | |
| 			},
 | |
| 			resp: `{
 | |
| 				"error": {
 | |
| 				  "code": null,
 | |
| 				  "message": "model not found",
 | |
| 				  "param": null,
 | |
| 				  "type": "api_error"
 | |
| 				}
 | |
| 			}`,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		router := gin.New()
 | |
| 		router.Use(RetrieveMiddleware())
 | |
| 		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
 | |
| 		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
 | |
| 
 | |
| 		resp := httptest.NewRecorder()
 | |
| 		router.ServeHTTP(resp, req)
 | |
| 
 | |
| 		var expected, actual map[string]any
 | |
| 		err := json.Unmarshal([]byte(tc.resp), &expected)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("failed to unmarshal expected response: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		err = json.Unmarshal(resp.Body.Bytes(), &actual)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("failed to unmarshal actual response: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		if !reflect.DeepEqual(expected, actual) {
 | |
| 			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
 | |
| 		}
 | |
| 	}
 | |
| }
 |