mirror of https://github.com/ollama/ollama.git
				
				
				
			
		
			
	
	
		
			413 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
		
		
			
		
	
	
			413 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
|  | //go:build integration
 | ||
|  | 
 | ||
|  | package integration | ||
|  | 
 | ||
|  | import ( | ||
|  | 	"bytes" | ||
|  | 	"context" | ||
|  | 	"fmt" | ||
|  | 	"math/rand" | ||
|  | 	"strings" | ||
|  | 	"testing" | ||
|  | 	"time" | ||
|  | 
 | ||
|  | 	"github.com/ollama/ollama/api" | ||
|  | ) | ||
|  | 
 | ||
|  | func TestAPIGenerate(t *testing.T) { | ||
|  | 	initialTimeout := 60 * time.Second | ||
|  | 	streamTimeout := 30 * time.Second | ||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) | ||
|  | 	defer cancel() | ||
|  | 	// Set up the test data
 | ||
|  | 	req := api.GenerateRequest{ | ||
|  | 		Model:  smol, | ||
|  | 		Prompt: "why is the sky blue? be brief", | ||
|  | 		Options: map[string]interface{}{ | ||
|  | 			"temperature": 0, | ||
|  | 			"seed":        123, | ||
|  | 		}, | ||
|  | 	} | ||
|  | 	anyResp := []string{"rayleigh", "scattering"} | ||
|  | 
 | ||
|  | 	client, _, cleanup := InitServerConnection(ctx, t) | ||
|  | 	defer cleanup() | ||
|  | 	if err := PullIfMissing(ctx, client, req.Model); err != nil { | ||
|  | 		t.Fatalf("pull failed %s", err) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	tests := []struct { | ||
|  | 		name   string | ||
|  | 		stream bool | ||
|  | 	}{ | ||
|  | 		{ | ||
|  | 			name:   "stream", | ||
|  | 			stream: true, | ||
|  | 		}, | ||
|  | 		{ | ||
|  | 			name:   "no_stream", | ||
|  | 			stream: false, | ||
|  | 		}, | ||
|  | 	} | ||
|  | 
 | ||
|  | 	for _, test := range tests { | ||
|  | 		t.Run(test.name, func(t *testing.T) { | ||
|  | 			stallTimer := time.NewTimer(initialTimeout) | ||
|  | 			var buf bytes.Buffer | ||
|  | 			fn := func(response api.GenerateResponse) error { | ||
|  | 				// Fields that must always be present
 | ||
|  | 				if response.Model == "" { | ||
|  | 					t.Errorf("response missing model: %#v", response) | ||
|  | 				} | ||
|  | 				if response.Done { | ||
|  | 					// Required fields for final updates:
 | ||
|  | 					if response.DoneReason == "" && *req.Stream { | ||
|  | 						// TODO - is the lack of done reason on non-stream a bug?
 | ||
|  | 						t.Errorf("final response missing done_reason: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.TotalDuration == 0 { | ||
|  | 						t.Errorf("final response missing total_duration: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.LoadDuration == 0 { | ||
|  | 						t.Errorf("final response missing load_duration: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.PromptEvalDuration == 0 { | ||
|  | 						t.Errorf("final response missing prompt_eval_duration: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.EvalCount == 0 { | ||
|  | 						t.Errorf("final response missing eval_count: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.EvalDuration == 0 { | ||
|  | 						t.Errorf("final response missing eval_duration: %#v", response) | ||
|  | 					} | ||
|  | 					if len(response.Context) == 0 { | ||
|  | 						t.Errorf("final response missing context: %#v", response) | ||
|  | 					} | ||
|  | 
 | ||
|  | 					// Note: caching can result in no prompt eval count, so this can't be verified reliably
 | ||
|  | 					// if response.Metrics.PromptEvalCount == 0 {
 | ||
|  | 					// 	t.Errorf("final response missing prompt_eval_count: %#v", response)
 | ||
|  | 					// }
 | ||
|  | 
 | ||
|  | 				} // else incremental response, nothing to check right now...
 | ||
|  | 				buf.Write([]byte(response.Response)) | ||
|  | 				if !stallTimer.Reset(streamTimeout) { | ||
|  | 					return fmt.Errorf("stall was detected while streaming response, aborting") | ||
|  | 				} | ||
|  | 				return nil | ||
|  | 			} | ||
|  | 
 | ||
|  | 			done := make(chan int) | ||
|  | 			var genErr error | ||
|  | 			go func() { | ||
|  | 				req.Stream = &test.stream | ||
|  | 				req.Options["seed"] = rand.Int() // bust cache for prompt eval results
 | ||
|  | 				genErr = client.Generate(ctx, &req, fn) | ||
|  | 				done <- 0 | ||
|  | 			}() | ||
|  | 
 | ||
|  | 			select { | ||
|  | 			case <-stallTimer.C: | ||
|  | 				if buf.Len() == 0 { | ||
|  | 					t.Errorf("generate never started.  Timed out after :%s", initialTimeout.String()) | ||
|  | 				} else { | ||
|  | 					t.Errorf("generate stalled.  Response so far:%s", buf.String()) | ||
|  | 				} | ||
|  | 			case <-done: | ||
|  | 				if genErr != nil { | ||
|  | 					t.Fatalf("failed with %s request prompt %s ", req.Model, req.Prompt) | ||
|  | 				} | ||
|  | 				// Verify the response contains the expected data
 | ||
|  | 				response := buf.String() | ||
|  | 				atLeastOne := false | ||
|  | 				for _, resp := range anyResp { | ||
|  | 					if strings.Contains(strings.ToLower(response), resp) { | ||
|  | 						atLeastOne = true | ||
|  | 						break | ||
|  | 					} | ||
|  | 				} | ||
|  | 				if !atLeastOne { | ||
|  | 					t.Errorf("none of %v found in %s", anyResp, response) | ||
|  | 				} | ||
|  | 			case <-ctx.Done(): | ||
|  | 				t.Error("outer test context done while waiting for generate") | ||
|  | 			} | ||
|  | 		}) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	// Validate PS while we're at it...
 | ||
|  | 	resp, err := client.ListRunning(ctx) | ||
|  | 	if err != nil { | ||
|  | 		t.Fatalf("list models API error: %s", err) | ||
|  | 	} | ||
|  | 	if resp == nil || len(resp.Models) == 0 { | ||
|  | 		t.Fatalf("list models API returned empty list while model should still be loaded") | ||
|  | 	} | ||
|  | 	// Find the model we just loaded and verify some attributes
 | ||
|  | 	found := false | ||
|  | 	for _, model := range resp.Models { | ||
|  | 		if strings.Contains(model.Name, req.Model) { | ||
|  | 			found = true | ||
|  | 			if model.Model == "" { | ||
|  | 				t.Errorf("model field omitted: %#v", model) | ||
|  | 			} | ||
|  | 			if model.Size == 0 { | ||
|  | 				t.Errorf("size omitted: %#v", model) | ||
|  | 			} | ||
|  | 			if model.Digest == "" { | ||
|  | 				t.Errorf("digest omitted: %#v", model) | ||
|  | 			} | ||
|  | 			verifyModelDetails(t, model.Details) | ||
|  | 			var nilTime time.Time | ||
|  | 			if model.ExpiresAt == nilTime { | ||
|  | 				t.Errorf("expires_at omitted: %#v", model) | ||
|  | 			} | ||
|  | 			// SizeVRAM could be zero.
 | ||
|  | 		} | ||
|  | 	} | ||
|  | 	if !found { | ||
|  | 		t.Errorf("unable to locate running model: %#v", resp) | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | func TestAPIChat(t *testing.T) { | ||
|  | 	initialTimeout := 60 * time.Second | ||
|  | 	streamTimeout := 30 * time.Second | ||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) | ||
|  | 	defer cancel() | ||
|  | 	// Set up the test data
 | ||
|  | 	req := api.ChatRequest{ | ||
|  | 		Model: smol, | ||
|  | 		Messages: []api.Message{ | ||
|  | 			{ | ||
|  | 				Role:    "user", | ||
|  | 				Content: "why is the sky blue?  be brief", | ||
|  | 			}, | ||
|  | 		}, | ||
|  | 		Options: map[string]interface{}{ | ||
|  | 			"temperature": 0, | ||
|  | 			"seed":        123, | ||
|  | 		}, | ||
|  | 	} | ||
|  | 	anyResp := []string{"rayleigh", "scattering"} | ||
|  | 
 | ||
|  | 	client, _, cleanup := InitServerConnection(ctx, t) | ||
|  | 	defer cleanup() | ||
|  | 	if err := PullIfMissing(ctx, client, req.Model); err != nil { | ||
|  | 		t.Fatalf("pull failed %s", err) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	tests := []struct { | ||
|  | 		name   string | ||
|  | 		stream bool | ||
|  | 	}{ | ||
|  | 		{ | ||
|  | 			name:   "stream", | ||
|  | 			stream: true, | ||
|  | 		}, | ||
|  | 		{ | ||
|  | 			name:   "no_stream", | ||
|  | 			stream: false, | ||
|  | 		}, | ||
|  | 	} | ||
|  | 
 | ||
|  | 	for _, test := range tests { | ||
|  | 		t.Run(test.name, func(t *testing.T) { | ||
|  | 			stallTimer := time.NewTimer(initialTimeout) | ||
|  | 			var buf bytes.Buffer | ||
|  | 			fn := func(response api.ChatResponse) error { | ||
|  | 				// Fields that must always be present
 | ||
|  | 				if response.Model == "" { | ||
|  | 					t.Errorf("response missing model: %#v", response) | ||
|  | 				} | ||
|  | 				if response.Done { | ||
|  | 					// Required fields for final updates:
 | ||
|  | 					var nilTime time.Time | ||
|  | 					if response.CreatedAt == nilTime { | ||
|  | 						t.Errorf("final response missing total_duration: %#v", response) | ||
|  | 					} | ||
|  | 					if response.DoneReason == "" { | ||
|  | 						t.Errorf("final response missing done_reason: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.TotalDuration == 0 { | ||
|  | 						t.Errorf("final response missing total_duration: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.LoadDuration == 0 { | ||
|  | 						t.Errorf("final response missing load_duration: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.PromptEvalDuration == 0 { | ||
|  | 						t.Errorf("final response missing prompt_eval_duration: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.EvalCount == 0 { | ||
|  | 						t.Errorf("final response missing eval_count: %#v", response) | ||
|  | 					} | ||
|  | 					if response.Metrics.EvalDuration == 0 { | ||
|  | 						t.Errorf("final response missing eval_duration: %#v", response) | ||
|  | 					} | ||
|  | 
 | ||
|  | 					if response.Metrics.PromptEvalCount == 0 { | ||
|  | 						t.Errorf("final response missing prompt_eval_count: %#v", response) | ||
|  | 					} | ||
|  | 				} // else incremental response, nothing to check right now...
 | ||
|  | 				buf.Write([]byte(response.Message.Content)) | ||
|  | 				if !stallTimer.Reset(streamTimeout) { | ||
|  | 					return fmt.Errorf("stall was detected while streaming response, aborting") | ||
|  | 				} | ||
|  | 				return nil | ||
|  | 			} | ||
|  | 
 | ||
|  | 			done := make(chan int) | ||
|  | 			var genErr error | ||
|  | 			go func() { | ||
|  | 				req.Stream = &test.stream | ||
|  | 				req.Options["seed"] = rand.Int() // bust cache for prompt eval results
 | ||
|  | 				genErr = client.Chat(ctx, &req, fn) | ||
|  | 				done <- 0 | ||
|  | 			}() | ||
|  | 
 | ||
|  | 			select { | ||
|  | 			case <-stallTimer.C: | ||
|  | 				if buf.Len() == 0 { | ||
|  | 					t.Errorf("chat never started.  Timed out after :%s", initialTimeout.String()) | ||
|  | 				} else { | ||
|  | 					t.Errorf("chat stalled.  Response so far:%s", buf.String()) | ||
|  | 				} | ||
|  | 			case <-done: | ||
|  | 				if genErr != nil { | ||
|  | 					t.Fatalf("failed with %s request prompt %v", req.Model, req.Messages) | ||
|  | 				} | ||
|  | 				// Verify the response contains the expected data
 | ||
|  | 				response := buf.String() | ||
|  | 				atLeastOne := false | ||
|  | 				for _, resp := range anyResp { | ||
|  | 					if strings.Contains(strings.ToLower(response), resp) { | ||
|  | 						atLeastOne = true | ||
|  | 						break | ||
|  | 					} | ||
|  | 				} | ||
|  | 				if !atLeastOne { | ||
|  | 					t.Errorf("none of %v found in %s", anyResp, response) | ||
|  | 				} | ||
|  | 			case <-ctx.Done(): | ||
|  | 				t.Error("outer test context done while waiting for chat") | ||
|  | 			} | ||
|  | 		}) | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | func TestAPIListModels(t *testing.T) { | ||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) | ||
|  | 	defer cancel() | ||
|  | 	client, _, cleanup := InitServerConnection(ctx, t) | ||
|  | 	defer cleanup() | ||
|  | 
 | ||
|  | 	// Make sure we have at least one model so an empty list can be considered a failure
 | ||
|  | 	if err := PullIfMissing(ctx, client, smol); err != nil { | ||
|  | 		t.Fatalf("pull failed %s", err) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	resp, err := client.List(ctx) | ||
|  | 	if err != nil { | ||
|  | 		t.Fatalf("unable to list models: %s", err) | ||
|  | 	} | ||
|  | 	if len(resp.Models) == 0 { | ||
|  | 		t.Fatalf("list should not be empty") | ||
|  | 	} | ||
|  | 	model := resp.Models[0] | ||
|  | 	if model.Name == "" { | ||
|  | 		t.Errorf("first model name empty: %#v", model) | ||
|  | 	} | ||
|  | 	var nilTime time.Time | ||
|  | 	if model.ModifiedAt == nilTime { | ||
|  | 		t.Errorf("first model modified_at empty: %#v", model) | ||
|  | 	} | ||
|  | 	if model.Size == 0 { | ||
|  | 		t.Errorf("first model size empty: %#v", model) | ||
|  | 	} | ||
|  | 	if model.Digest == "" { | ||
|  | 		t.Errorf("first model digest empty: %#v", model) | ||
|  | 	} | ||
|  | 	verifyModelDetails(t, model.Details) | ||
|  | } | ||
|  | 
 | ||
|  | func verifyModelDetails(t *testing.T, details api.ModelDetails) { | ||
|  | 	if details.Format == "" { | ||
|  | 		t.Errorf("first model details.format empty: %#v", details) | ||
|  | 	} | ||
|  | 	if details.Family == "" { | ||
|  | 		t.Errorf("first model details.family empty: %#v", details) | ||
|  | 	} | ||
|  | 	if details.ParameterSize == "" { | ||
|  | 		t.Errorf("first model details.parameter_size empty: %#v", details) | ||
|  | 	} | ||
|  | 	if details.QuantizationLevel == "" { | ||
|  | 		t.Errorf("first model details.quantization_level empty: %#v", details) | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | func TestAPIShowModel(t *testing.T) { | ||
|  | 	modelName := "llama3.2" | ||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) | ||
|  | 	defer cancel() | ||
|  | 	client, _, cleanup := InitServerConnection(ctx, t) | ||
|  | 	defer cleanup() | ||
|  | 
 | ||
|  | 	if err := PullIfMissing(ctx, client, modelName); err != nil { | ||
|  | 		t.Fatalf("pull failed %s", err) | ||
|  | 	} | ||
|  | 	resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName}) | ||
|  | 	if err != nil { | ||
|  | 		t.Fatalf("unable to show model: %s", err) | ||
|  | 	} | ||
|  | 	if resp.License == "" { | ||
|  | 		t.Errorf("%s missing license: %#v", modelName, resp) | ||
|  | 	} | ||
|  | 	if resp.Modelfile == "" { | ||
|  | 		t.Errorf("%s missing modelfile: %#v", modelName, resp) | ||
|  | 	} | ||
|  | 	if resp.Parameters == "" { | ||
|  | 		t.Errorf("%s missing parameters: %#v", modelName, resp) | ||
|  | 	} | ||
|  | 	if resp.Template == "" { | ||
|  | 		t.Errorf("%s missing template: %#v", modelName, resp) | ||
|  | 	} | ||
|  | 	// llama3 omits system
 | ||
|  | 	verifyModelDetails(t, resp.Details) | ||
|  | 	// llama3 ommits messages
 | ||
|  | 	if len(resp.ModelInfo) == 0 { | ||
|  | 		t.Errorf("%s missing model_info: %#v", modelName, resp) | ||
|  | 	} | ||
|  | 	// llama3 omits projectors
 | ||
|  | 	var nilTime time.Time | ||
|  | 	if resp.ModifiedAt == nilTime { | ||
|  | 		t.Errorf("%s missing modified_at: %#v", modelName, resp) | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | func TestAPIEmbeddings(t *testing.T) { | ||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) | ||
|  | 	defer cancel() | ||
|  | 	client, _, cleanup := InitServerConnection(ctx, t) | ||
|  | 	defer cleanup() | ||
|  | 	req := api.EmbeddingRequest{ | ||
|  | 		Model:  "orca-mini", | ||
|  | 		Prompt: "why is the sky blue?", | ||
|  | 		Options: map[string]interface{}{ | ||
|  | 			"temperature": 0, | ||
|  | 			"seed":        123, | ||
|  | 		}, | ||
|  | 	} | ||
|  | 
 | ||
|  | 	if err := PullIfMissing(ctx, client, req.Model); err != nil { | ||
|  | 		t.Fatalf("pull failed %s", err) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	resp, err := client.Embeddings(ctx, &req) | ||
|  | 	if err != nil { | ||
|  | 		t.Fatalf("embeddings call failed %s", err) | ||
|  | 	} | ||
|  | 	if len(resp.Embedding) == 0 { | ||
|  | 		t.Errorf("zero length embedding response") | ||
|  | 	} | ||
|  | } |