mirror of https://github.com/ollama/ollama.git
				
				
				
			
		
			
				
	
	
		
			971 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			971 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Go
		
	
	
	
| package server
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"encoding/binary"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"io/fs"
 | |
| 	"math"
 | |
| 	"math/rand/v2"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"reflect"
 | |
| 	"sort"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 	"unicode"
 | |
| 
 | |
| 	"github.com/gin-gonic/gin"
 | |
| 	"github.com/google/go-cmp/cmp"
 | |
| 	"github.com/ollama/ollama/api"
 | |
| 	"github.com/ollama/ollama/fs/ggml"
 | |
| 	"github.com/ollama/ollama/openai"
 | |
| 	"github.com/ollama/ollama/server/internal/client/ollama"
 | |
| 	"github.com/ollama/ollama/types/model"
 | |
| 	"github.com/ollama/ollama/version"
 | |
| )
 | |
| 
 | |
| func createTestFile(t *testing.T, name string) (string, string) {
 | |
| 	t.Helper()
 | |
| 
 | |
| 	modelDir := os.Getenv("OLLAMA_MODELS")
 | |
| 	if modelDir == "" {
 | |
| 		t.Fatalf("OLLAMA_MODELS not specified")
 | |
| 	}
 | |
| 
 | |
| 	f, err := os.CreateTemp(t.TempDir(), name)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to create temp file: %v", err)
 | |
| 	}
 | |
| 	defer f.Close()
 | |
| 
 | |
| 	err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to write to file: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	err = binary.Write(f, binary.LittleEndian, uint32(3))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to write to file: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	err = binary.Write(f, binary.LittleEndian, uint64(0))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to write to file: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	err = binary.Write(f, binary.LittleEndian, uint64(0))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to write to file: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Calculate sha256 sum of file
 | |
| 	if _, err := f.Seek(0, 0); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	digest, _ := GetSHA256Digest(f)
 | |
| 	if err := f.Close(); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	if err := createLink(f.Name(), filepath.Join(modelDir, "blobs", fmt.Sprintf("sha256-%s", strings.TrimPrefix(digest, "sha256:")))); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	return f.Name(), digest
 | |
| }
 | |
| 
 | |
| // equalStringSlices checks if two slices of strings are equal.
 | |
// equalStringSlices reports whether a and b have the same length and hold
// the same strings at every position. A nil slice and an empty slice are
// considered equal.
func equalStringSlices(a, b []string) bool {
	n := len(a)
	if n != len(b) {
		return false
	}
	for idx := 0; idx < n; idx++ {
		if a[idx] != b[idx] {
			return false
		}
	}
	return true
}
 | |
| 
 | |
// panicTransport is an http.RoundTripper that refuses to perform any
// request: every RoundTrip call panics.
type panicTransport struct{}

// RoundTrip always panics, so a test fails loudly if code under test tries
// to send a request through this transport.
func (*panicTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

// panicOnRoundTrip is an HTTP client whose transport panics on use.
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
 | |
| 
 | |
| func TestRoutes(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		Name     string
 | |
| 		Method   string
 | |
| 		Path     string
 | |
| 		Setup    func(t *testing.T, req *http.Request)
 | |
| 		Expected func(t *testing.T, resp *http.Response)
 | |
| 	}
 | |
| 
 | |
| 	createTestModel := func(t *testing.T, name string) {
 | |
| 		t.Helper()
 | |
| 
 | |
| 		_, digest := createTestFile(t, "ollama-model")
 | |
| 
 | |
| 		fn := func(resp api.ProgressResponse) {
 | |
| 			t.Logf("Status: %s", resp.Status)
 | |
| 		}
 | |
| 
 | |
| 		r := api.CreateRequest{
 | |
| 			Name:  name,
 | |
| 			Files: map[string]string{"test.gguf": digest},
 | |
| 			Parameters: map[string]any{
 | |
| 				"seed":  42,
 | |
| 				"top_p": 0.9,
 | |
| 				"stop":  []string{"foo", "bar"},
 | |
| 			},
 | |
| 		}
 | |
| 
 | |
| 		modelName := model.ParseName(name)
 | |
| 
 | |
| 		baseLayers, err := ggufLayers(digest, fn)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("failed to create model: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		if err := createModel(r, modelName, baseLayers, fn); err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			Name:   "Version Handler",
 | |
| 			Method: http.MethodGet,
 | |
| 			Path:   "/api/version",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				contentType := resp.Header.Get("Content-Type")
 | |
| 				if contentType != "application/json; charset=utf-8" {
 | |
| 					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
 | |
| 				}
 | |
| 				body, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 				expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version)
 | |
| 				if string(body) != expectedBody {
 | |
| 					t.Errorf("expected body %s, got %s", expectedBody, string(body))
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "Tags Handler (no tags)",
 | |
| 			Method: http.MethodGet,
 | |
| 			Path:   "/api/tags",
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				contentType := resp.Header.Get("Content-Type")
 | |
| 				if contentType != "application/json; charset=utf-8" {
 | |
| 					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
 | |
| 				}
 | |
| 				body, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				var modelList api.ListResponse
 | |
| 
 | |
| 				err = json.Unmarshal(body, &modelList)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to unmarshal response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if modelList.Models == nil || len(modelList.Models) != 0 {
 | |
| 					t.Errorf("expected empty model list, got %v", modelList.Models)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "openai empty list",
 | |
| 			Method: http.MethodGet,
 | |
| 			Path:   "/v1/models",
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				contentType := resp.Header.Get("Content-Type")
 | |
| 				if contentType != "application/json" {
 | |
| 					t.Errorf("expected content type application/json, got %s", contentType)
 | |
| 				}
 | |
| 				body, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				var modelList openai.ListCompletion
 | |
| 				err = json.Unmarshal(body, &modelList)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to unmarshal response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if modelList.Object != "list" || len(modelList.Data) != 0 {
 | |
| 					t.Errorf("expected empty model list, got %v", modelList.Data)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "Tags Handler (yes tags)",
 | |
| 			Method: http.MethodGet,
 | |
| 			Path:   "/api/tags",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				createTestModel(t, "test-model")
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				contentType := resp.Header.Get("Content-Type")
 | |
| 				if contentType != "application/json; charset=utf-8" {
 | |
| 					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
 | |
| 				}
 | |
| 				body, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if strings.Contains(string(body), "expires_at") {
 | |
| 					t.Errorf("response body should not contain 'expires_at'")
 | |
| 				}
 | |
| 
 | |
| 				var modelList api.ListResponse
 | |
| 				err = json.Unmarshal(body, &modelList)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to unmarshal response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" {
 | |
| 					t.Errorf("expected model 'test-model:latest', got %v", modelList.Models)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "Delete Model Handler",
 | |
| 			Method: http.MethodDelete,
 | |
| 			Path:   "/api/delete",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				createTestModel(t, "model_to_delete")
 | |
| 
 | |
| 				deleteReq := api.DeleteRequest{
 | |
| 					Name: "model_to_delete",
 | |
| 				}
 | |
| 				jsonData, err := json.Marshal(deleteReq)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to marshal delete request: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				req.Body = io.NopCloser(bytes.NewReader(jsonData))
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				if resp.StatusCode != http.StatusOK {
 | |
| 					t.Errorf("expected status code 200, got %d", resp.StatusCode)
 | |
| 				}
 | |
| 
 | |
| 				// Verify the model was deleted
 | |
| 				_, err := GetModel("model-to-delete")
 | |
| 				if err == nil || !os.IsNotExist(err) {
 | |
| 					t.Errorf("expected model to be deleted, got error %v", err)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "Delete Non-existent Model",
 | |
| 			Method: http.MethodDelete,
 | |
| 			Path:   "/api/delete",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				deleteReq := api.DeleteRequest{
 | |
| 					Name: "non_existent_model",
 | |
| 				}
 | |
| 				jsonData, err := json.Marshal(deleteReq)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to marshal delete request: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				req.Body = io.NopCloser(bytes.NewReader(jsonData))
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				if resp.StatusCode != http.StatusNotFound {
 | |
| 					t.Errorf("expected status code 404, got %d", resp.StatusCode)
 | |
| 				}
 | |
| 
 | |
| 				body, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				var errorResp map[string]string
 | |
| 				err = json.Unmarshal(body, &errorResp)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to unmarshal response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if !strings.Contains(errorResp["error"], "not found") {
 | |
| 					t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"])
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "openai list models with tags",
 | |
| 			Method: http.MethodGet,
 | |
| 			Path:   "/v1/models",
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				contentType := resp.Header.Get("Content-Type")
 | |
| 				if contentType != "application/json" {
 | |
| 					t.Errorf("expected content type application/json, got %s", contentType)
 | |
| 				}
 | |
| 				body, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				var modelList openai.ListCompletion
 | |
| 				err = json.Unmarshal(body, &modelList)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to unmarshal response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" {
 | |
| 					t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "Create Model Handler",
 | |
| 			Method: http.MethodPost,
 | |
| 			Path:   "/api/create",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				_, digest := createTestFile(t, "ollama-model")
 | |
| 				stream := false
 | |
| 				createReq := api.CreateRequest{
 | |
| 					Name:   "t-bone",
 | |
| 					Files:  map[string]string{"test.gguf": digest},
 | |
| 					Stream: &stream,
 | |
| 				}
 | |
| 				jsonData, err := json.Marshal(createReq)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to marshal create request: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				req.Body = io.NopCloser(bytes.NewReader(jsonData))
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				contentType := resp.Header.Get("Content-Type")
 | |
| 				if contentType != "application/json" {
 | |
| 					t.Errorf("expected content type application/json, got %s", contentType)
 | |
| 				}
 | |
| 				_, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 				if resp.StatusCode != http.StatusOK { // Updated line
 | |
| 					t.Errorf("expected status code 200, got %d", resp.StatusCode)
 | |
| 				}
 | |
| 
 | |
| 				model, err := GetModel("t-bone")
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to get model: %v", err)
 | |
| 				}
 | |
| 				if model.ShortName != "t-bone:latest" {
 | |
| 					t.Errorf("expected model name 't-bone:latest', got %s", model.ShortName)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "Copy Model Handler",
 | |
| 			Method: http.MethodPost,
 | |
| 			Path:   "/api/copy",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				createTestModel(t, "hamshank")
 | |
| 				copyReq := api.CopyRequest{
 | |
| 					Source:      "hamshank",
 | |
| 					Destination: "beefsteak",
 | |
| 				}
 | |
| 				jsonData, err := json.Marshal(copyReq)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to marshal copy request: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				req.Body = io.NopCloser(bytes.NewReader(jsonData))
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				model, err := GetModel("beefsteak")
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to get model: %v", err)
 | |
| 				}
 | |
| 				if model.ShortName != "beefsteak:latest" {
 | |
| 					t.Errorf("expected model name 'beefsteak:latest', got %s", model.ShortName)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "Show Model Handler",
 | |
| 			Method: http.MethodPost,
 | |
| 			Path:   "/api/show",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				createTestModel(t, "show-model")
 | |
| 				showReq := api.ShowRequest{Model: "show-model"}
 | |
| 				jsonData, err := json.Marshal(showReq)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to marshal show request: %v", err)
 | |
| 				}
 | |
| 				req.Body = io.NopCloser(bytes.NewReader(jsonData))
 | |
| 			},
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				contentType := resp.Header.Get("Content-Type")
 | |
| 				if contentType != "application/json; charset=utf-8" {
 | |
| 					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
 | |
| 				}
 | |
| 				body, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				var showResp api.ShowResponse
 | |
| 				err = json.Unmarshal(body, &showResp)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to unmarshal response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				var params []string
 | |
| 				paramsSplit := strings.Split(showResp.Parameters, "\n")
 | |
| 				for _, p := range paramsSplit {
 | |
| 					params = append(params, strings.Join(strings.Fields(p), " "))
 | |
| 				}
 | |
| 				sort.Strings(params)
 | |
| 				expectedParams := []string{
 | |
| 					"seed 42",
 | |
| 					"stop \"bar\"",
 | |
| 					"stop \"foo\"",
 | |
| 					"top_p 0.9",
 | |
| 				}
 | |
| 				if !equalStringSlices(params, expectedParams) {
 | |
| 					t.Errorf("expected parameters %v, got %v", expectedParams, params)
 | |
| 				}
 | |
| 				paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
 | |
| 				if !ok {
 | |
| 					t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"])
 | |
| 				}
 | |
| 				if math.Abs(paramCount) > 1e-9 {
 | |
| 					t.Errorf("expected parameter count to be 0, got %f", paramCount)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name: "openai retrieve model handler",
 | |
| 			Setup: func(t *testing.T, req *http.Request) {
 | |
| 				createTestModel(t, "show-model")
 | |
| 			},
 | |
| 			Method: http.MethodGet,
 | |
| 			Path:   "/v1/models/show-model",
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				contentType := resp.Header.Get("Content-Type")
 | |
| 				if contentType != "application/json" {
 | |
| 					t.Errorf("expected content type application/json, got %s", contentType)
 | |
| 				}
 | |
| 				body, err := io.ReadAll(resp.Body)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to read response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				var m openai.Model
 | |
| 				err = json.Unmarshal(body, &m)
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("failed to unmarshal response body: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if m.Id != "show-model" || m.OwnedBy != "library" {
 | |
| 					t.Errorf("expected model 'show-model' owned by 'library', got %v", m)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			Name:   "Method Not Allowed",
 | |
| 			Method: http.MethodGet,
 | |
| 			Path:   "/api/show",
 | |
| 			Expected: func(t *testing.T, resp *http.Response) {
 | |
| 				if resp.StatusCode != 405 {
 | |
| 					t.Errorf("expected status code 405, got %d", resp.StatusCode)
 | |
| 				}
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	modelsDir := t.TempDir()
 | |
| 	t.Setenv("OLLAMA_MODELS", modelsDir)
 | |
| 
 | |
| 	rc := &ollama.Registry{
 | |
| 		// This is a temporary measure to allow us to move forward,
 | |
| 		// surfacing any code contacting ollama.com we do not intended
 | |
| 		// to.
 | |
| 		//
 | |
| 		// Currently, this only handles DELETE /api/delete, which
 | |
| 		// should not make any contact with the ollama.com registry, so
 | |
| 		// be clear about that.
 | |
| 		//
 | |
| 		// Tests that do need to contact the registry here, will be
 | |
| 		// consumed into our new server/api code packages and removed
 | |
| 		// from here.
 | |
| 		HTTPClient: panicOnRoundTrip,
 | |
| 	}
 | |
| 
 | |
| 	s := &Server{}
 | |
| 	router, err := s.GenerateRoutes(rc)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to generate routes: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	httpSrv := httptest.NewServer(router)
 | |
| 	t.Cleanup(httpSrv.Close)
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.Name, func(t *testing.T) {
 | |
| 			u := httpSrv.URL + tc.Path
 | |
| 			req, err := http.NewRequestWithContext(t.Context(), tc.Method, u, nil)
 | |
| 			if err != nil {
 | |
| 				t.Fatalf("failed to create request: %v", err)
 | |
| 			}
 | |
| 
 | |
| 			if tc.Setup != nil {
 | |
| 				tc.Setup(t, req)
 | |
| 			}
 | |
| 
 | |
| 			resp, err := httpSrv.Client().Do(req)
 | |
| 			if err != nil {
 | |
| 				t.Fatalf("failed to do request: %v", err)
 | |
| 			}
 | |
| 			defer resp.Body.Close()
 | |
| 
 | |
| 			if tc.Expected != nil {
 | |
| 				tc.Expected(t, resp)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func casingShuffle(s string) string {
 | |
| 	rr := []rune(s)
 | |
| 	for i := range rr {
 | |
| 		if rand.N(2) == 0 {
 | |
| 			rr[i] = unicode.ToUpper(rr[i])
 | |
| 		} else {
 | |
| 			rr[i] = unicode.ToLower(rr[i])
 | |
| 		}
 | |
| 	}
 | |
| 	return string(rr)
 | |
| }
 | |
| 
 | |
| func TestManifestCaseSensitivity(t *testing.T) {
 | |
| 	t.Setenv("OLLAMA_MODELS", t.TempDir())
 | |
| 
 | |
| 	r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		w.WriteHeader(http.StatusOK)
 | |
| 		io.WriteString(w, `{}`) //nolint:errcheck
 | |
| 	}))
 | |
| 	defer r.Close()
 | |
| 
 | |
| 	nameUsed := make(map[string]bool)
 | |
| 	name := func() string {
 | |
| 		const fqmn = "example/namespace/model:tag"
 | |
| 		for {
 | |
| 			v := casingShuffle(fqmn)
 | |
| 			if nameUsed[v] {
 | |
| 				continue
 | |
| 			}
 | |
| 			nameUsed[v] = true
 | |
| 			return v
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	wantStableName := name()
 | |
| 
 | |
| 	t.Logf("stable name: %s", wantStableName)
 | |
| 
 | |
| 	// checkManifestList tests that there is strictly one manifest in the
 | |
| 	// models directory, and that the manifest is for the model under test.
 | |
| 	checkManifestList := func() {
 | |
| 		t.Helper()
 | |
| 
 | |
| 		mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
 | |
| 		var entries []string
 | |
| 		t.Logf("dir entries:")
 | |
| 		fsys := os.DirFS(mandir)
 | |
| 		err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			t.Logf("    %s", fs.FormatDirEntry(info))
 | |
| 			if info.IsDir() {
 | |
| 				return nil
 | |
| 			}
 | |
| 			path = strings.TrimPrefix(path, mandir)
 | |
| 			entries = append(entries, path)
 | |
| 			return nil
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("failed to walk directory: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		if len(entries) != 1 {
 | |
| 			t.Errorf("len(got) = %d, want 1", len(entries))
 | |
| 			return // do not use Fatal so following steps run
 | |
| 		}
 | |
| 
 | |
| 		g := entries[0] // raw path
 | |
| 		g = filepath.ToSlash(g)
 | |
| 		w := model.ParseName(wantStableName).Filepath()
 | |
| 		w = filepath.ToSlash(w)
 | |
| 		if g != w {
 | |
| 			t.Errorf("\ngot:  %s\nwant: %s", g, w)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	checkOK := func(w *httptest.ResponseRecorder) {
 | |
| 		t.Helper()
 | |
| 		if w.Code != http.StatusOK {
 | |
| 			t.Errorf("code = %d, want 200", w.Code)
 | |
| 			t.Logf("body: %s", w.Body.String())
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	var s Server
 | |
| 	testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
 | |
| 		var d net.Dialer
 | |
| 		return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
 | |
| 	}
 | |
| 	t.Cleanup(func() { testMakeRequestDialContext = nil })
 | |
| 
 | |
| 	t.Logf("creating")
 | |
| 	_, digest := createBinFile(t, nil, nil)
 | |
| 	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
 | |
| 		// Start with the stable name, and later use a case-shuffled
 | |
| 		// version.
 | |
| 		Name:   wantStableName,
 | |
| 		Files:  map[string]string{"test.gguf": digest},
 | |
| 		Stream: &stream,
 | |
| 	}))
 | |
| 	checkManifestList()
 | |
| 
 | |
| 	t.Logf("creating (again)")
 | |
| 	checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
 | |
| 		Name:   name(),
 | |
| 		Files:  map[string]string{"test.gguf": digest},
 | |
| 		Stream: &stream,
 | |
| 	}))
 | |
| 	checkManifestList()
 | |
| 
 | |
| 	t.Logf("pulling")
 | |
| 	checkOK(createRequest(t, s.PullHandler, api.PullRequest{
 | |
| 		Name:     name(),
 | |
| 		Stream:   &stream,
 | |
| 		Insecure: true,
 | |
| 	}))
 | |
| 	checkManifestList()
 | |
| 
 | |
| 	t.Logf("copying")
 | |
| 	checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
 | |
| 		Source:      name(),
 | |
| 		Destination: name(),
 | |
| 	}))
 | |
| 	checkManifestList()
 | |
| 
 | |
| 	t.Logf("pushing")
 | |
| 	rr := createRequest(t, s.PushHandler, api.PushRequest{
 | |
| 		Model:    name(),
 | |
| 		Insecure: true,
 | |
| 		Username: "alice",
 | |
| 		Password: "x",
 | |
| 	})
 | |
| 	checkOK(rr)
 | |
| 	if !strings.Contains(rr.Body.String(), `"status":"success"`) {
 | |
| 		t.Errorf("got = %q, want success", rr.Body.String())
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestShow(t *testing.T) {
 | |
| 	t.Setenv("OLLAMA_MODELS", t.TempDir())
 | |
| 
 | |
| 	var s Server
 | |
| 
 | |
| 	_, digest1 := createBinFile(t, ggml.KV{"general.architecture": "test"}, nil)
 | |
| 	_, digest2 := createBinFile(t, ggml.KV{"general.type": "projector", "general.architecture": "clip"}, nil)
 | |
| 
 | |
| 	createRequest(t, s.CreateHandler, api.CreateRequest{
 | |
| 		Name:  "show-model",
 | |
| 		Files: map[string]string{"model.gguf": digest1, "projector.gguf": digest2},
 | |
| 	})
 | |
| 
 | |
| 	w := createRequest(t, s.ShowHandler, api.ShowRequest{
 | |
| 		Name: "show-model",
 | |
| 	})
 | |
| 
 | |
| 	if w.Code != http.StatusOK {
 | |
| 		t.Fatalf("expected status code 200, actual %d", w.Code)
 | |
| 	}
 | |
| 
 | |
| 	var resp api.ShowResponse
 | |
| 	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	if resp.ModelInfo["general.architecture"] != "test" {
 | |
| 		t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"])
 | |
| 	}
 | |
| 
 | |
| 	if resp.ProjectorInfo["general.architecture"] != "clip" {
 | |
| 		t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestNormalize(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		input []float32
 | |
| 	}
 | |
| 
 | |
| 	testCases := []testCase{
 | |
| 		{input: []float32{1}},
 | |
| 		{input: []float32{0, 1, 2, 3}},
 | |
| 		{input: []float32{0.1, 0.2, 0.3}},
 | |
| 		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
 | |
| 		{input: []float32{0, 0, 0}},
 | |
| 	}
 | |
| 
 | |
| 	isNormalized := func(vec []float32) (res bool) {
 | |
| 		sum := 0.0
 | |
| 		for _, v := range vec {
 | |
| 			sum += float64(v * v)
 | |
| 		}
 | |
| 		if math.Abs(sum-1) > 1e-6 {
 | |
| 			return sum == 0
 | |
| 		} else {
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run("", func(t *testing.T) {
 | |
| 			normalized := normalize(tc.input)
 | |
| 			if !isNormalized(normalized) {
 | |
| 				t.Errorf("Vector %v is not normalized", tc.input)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestFilterThinkTags(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		msgs  []api.Message
 | |
| 		want  []api.Message
 | |
| 		model *Model
 | |
| 	}
 | |
| 	testCases := []testCase{
 | |
| 		{
 | |
| 			msgs: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 			},
 | |
| 			want: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "abc"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 			},
 | |
| 			model: &Model{
 | |
| 				Config: ConfigV2{
 | |
| 					ModelFamily: "qwen3",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 		// with newlines inside the think tag aned newlines after
 | |
| 		{
 | |
| 			msgs: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "<think>Thinking... \n\nabout \nthe answer</think>\n\nabc\ndef"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 			},
 | |
| 			want: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "abc\ndef"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 			},
 | |
| 			model: &Model{
 | |
| 				Config: ConfigV2{
 | |
| 					ModelFamily: "qwen3",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 		// should leave thinking tags if it's after the last user message
 | |
| 		{
 | |
| 			msgs: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "<think>Thinking...</think>after"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 				{Role: "assistant", Content: "<think>thinking again</think>hjk"},
 | |
| 				{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
 | |
| 			},
 | |
| 			want: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "after"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 				{Role: "assistant", Content: "<think>thinking again</think>hjk"},
 | |
| 				{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
 | |
| 			},
 | |
| 			model: &Model{
 | |
| 				Config: ConfigV2{
 | |
| 					ModelFamily: "qwen3",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			// shouldn't strip anything because the model family isn't one of the hardcoded ones
 | |
| 			msgs: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 			},
 | |
| 			want: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 			},
 | |
| 			model: &Model{
 | |
| 				Config: ConfigV2{
 | |
| 					ModelFamily: "llama3",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			// deepseek-r1:-prefixed model
 | |
| 			msgs: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 			},
 | |
| 			want: []api.Message{
 | |
| 				{Role: "user", Content: "Hello, world!"},
 | |
| 				{Role: "assistant", Content: "abc"},
 | |
| 				{Role: "user", Content: "What is the answer?"},
 | |
| 			},
 | |
| 			model: &Model{
 | |
| 				Name:      "registry.ollama.ai/library/deepseek-r1:latest",
 | |
| 				ShortName: "deepseek-r1:7b",
 | |
| 				Config:    ConfigV2{},
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for i, tc := range testCases {
 | |
| 		filtered := filterThinkTags(tc.msgs, tc.model)
 | |
| 
 | |
| 		if !reflect.DeepEqual(filtered, tc.want) {
 | |
| 			t.Errorf("messages differ for case %d:", i)
 | |
| 			for i := range tc.want {
 | |
| 				if i >= len(filtered) {
 | |
| 					t.Errorf("  missing message %d: %+v", i, tc.want[i])
 | |
| 					continue
 | |
| 				}
 | |
| 				if !reflect.DeepEqual(filtered[i], tc.want[i]) {
 | |
| 					t.Errorf("  message %d:\n    want: %+v\n    got:  %+v", i, tc.want[i], filtered[i])
 | |
| 				}
 | |
| 			}
 | |
| 			if len(filtered) > len(tc.want) {
 | |
| 				for i := len(tc.want); i < len(filtered); i++ {
 | |
| 					t.Errorf("  extra message %d: %+v", i, filtered[i])
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestWaitForStream(t *testing.T) {
 | |
| 	gin.SetMode(gin.TestMode)
 | |
| 
 | |
| 	cases := []struct {
 | |
| 		name       string
 | |
| 		messages   []any
 | |
| 		expectCode int
 | |
| 		expectBody string
 | |
| 	}{
 | |
| 		{
 | |
| 			name: "error",
 | |
| 			messages: []any{
 | |
| 				gin.H{"error": "internal server error"},
 | |
| 			},
 | |
| 			expectCode: http.StatusInternalServerError,
 | |
| 			expectBody: `{"error":"internal server error"}`,
 | |
| 		},
 | |
| 		{
 | |
| 			name: "error status",
 | |
| 			messages: []any{
 | |
| 				gin.H{"status": http.StatusNotFound, "error": "not found"},
 | |
| 			},
 | |
| 			expectCode: http.StatusNotFound,
 | |
| 			expectBody: `{"error":"not found"}`,
 | |
| 		},
 | |
| 		{
 | |
| 			name: "unknown error",
 | |
| 			messages: []any{
 | |
| 				gin.H{"msg": "something else"},
 | |
| 			},
 | |
| 			expectCode: http.StatusInternalServerError,
 | |
| 			expectBody: `{"error":"unknown error"}`,
 | |
| 		},
 | |
| 		{
 | |
| 			name: "unknown type",
 | |
| 			messages: []any{
 | |
| 				struct{}{},
 | |
| 			},
 | |
| 			expectCode: http.StatusInternalServerError,
 | |
| 			expectBody: `{"error":"unknown message type"}`,
 | |
| 		},
 | |
| 		{
 | |
| 			name: "progress success",
 | |
| 			messages: []any{
 | |
| 				api.ProgressResponse{Status: "success"},
 | |
| 			},
 | |
| 			expectCode: http.StatusOK,
 | |
| 			expectBody: `{"status":"success"}`,
 | |
| 		},
 | |
| 		{
 | |
| 			name: "progress more than success",
 | |
| 			messages: []any{
 | |
| 				api.ProgressResponse{Status: "success"},
 | |
| 				api.ProgressResponse{Status: "one more thing"},
 | |
| 			},
 | |
| 			expectCode: http.StatusOK,
 | |
| 			expectBody: `{"status":"one more thing"}`,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tt := range cases {
 | |
| 		t.Run(tt.name, func(t *testing.T) {
 | |
| 			w := httptest.NewRecorder()
 | |
| 			c, _ := gin.CreateTestContext(w)
 | |
| 
 | |
| 			ch := make(chan any, len(tt.messages))
 | |
| 			for _, msg := range tt.messages {
 | |
| 				ch <- msg
 | |
| 			}
 | |
| 			close(ch)
 | |
| 
 | |
| 			waitForStream(c, ch)
 | |
| 
 | |
| 			if w.Code != tt.expectCode {
 | |
| 				t.Errorf("expected status %d, got %d", tt.expectCode, w.Code)
 | |
| 			}
 | |
| 
 | |
| 			if diff := cmp.Diff(w.Body.String(), tt.expectBody); diff != "" {
 | |
| 				t.Errorf("body mismatch (-want +got):\n%s", diff)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 |