//go:build integration && perf

package integration

import (
	"context"
	"fmt"
	"log/slog"
	"math"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
)

var (
	// Models that don't work reliably with the large context prompt in this test case
	longContextFlakes = []string{
		"granite-code:latest",
		"nemotron-mini:latest",
		"falcon:latest",  // 2k model
		"falcon2:latest", // 2k model
		"minicpm-v:latest",
		"qwen:latest",
		"solar-pro:latest",
	}
)

// Note: this test case can take a long time to run, particularly on models with
// large contexts. Run with -timeout set to a large value to get reasonable coverage.
// Example usage:
//
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
// cat int.log | grep MODEL_PERF_HEADER | head -1 | cut -f2- -d: > perf.csv
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) {
	softTimeout, hardTimeout := getTimeouts(t)
	slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// TODO use info API eventually
	var maxVram uint64
	var err error
	if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
		maxVram, err = strconv.ParseUint(s, 10, 64)
		if err != nil {
			t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
		}
	} else {
		slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
	}

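	// Build the long prompt from the bundled Shakespeare text; it is used below to exercise large context sizes.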
	data, err := os.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
	if err != nil {
		t.Fatalf("failed to open test data file: %s", err)
	}
	longPrompt := "summarize the following: " + string(data)

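	// With OLLAMA_NEW_ENGINE set, only models served by the Ollama engine are tested;
	// otherwise the llama.cpp runner models are included as well.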
	var chatModels []string
	if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
		chatModels = ollamaEngineChatModels
	} else {
		chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
	}

	for _, model := range chatModels {
		t.Run(model, func(t *testing.T) {
			if time.Since(started) > softTimeout {
				t.Skip("skipping remaining tests to avoid excessive runtime")
			}
			if err := PullIfMissing(ctx, client, model); err != nil {
				t.Fatalf("pull failed %s", err)
			}
			var maxContext int

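			// Look up the model's architecture so its maximum context length can be read from the metadata.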
			resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
			if err != nil {
				t.Fatalf("show failed: %s", err)
			}
			arch := resp.ModelInfo["general.architecture"].(string)
			maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))

			if maxVram > 0 {
				resp, err := client.List(ctx)
				if err != nil {
					t.Fatalf("list models failed %v", err)
				}
				for _, m := range resp.Models {
					// For these tests we want to exercise some amount of overflow on the CPU
					if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
						t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
					}
				}
			}
			slog.Info("scenario", "model", model, "max_context", maxContext)
			loaded := false
			defer func() {
				// best effort unload once we're done with the model
				if loaded {
					client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
				}
			}()

			// Some models don't handle the long context data well so skip them to avoid flaky test results
			longContextFlake := false
			for _, flake := range longContextFlakes {
				if model == flake {
					longContextFlake = true
					break
				}
			}

			// iterate through a few context sizes for coverage without excessive runtime
			var contexts []int
			keepGoing := true
			if maxContext > 16384 {
				contexts = []int{4096, 8192, 16384, maxContext}
			} else if maxContext > 8192 {
				contexts = []int{4096, 8192, maxContext}
			} else if maxContext > 4096 {
				contexts = []int{4096, maxContext}
			} else if maxContext > 0 {
				contexts = []int{maxContext}
			} else {
				t.Fatal("unknown max context size")
			}
			for _, numCtx := range contexts {
				if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
					break
				}
				skipLongPrompt := false

				// Workaround bug 11172 temporarily...
				maxPrompt := longPrompt
				// If we fill the context too full with the prompt, many models
				// quickly hit context shifting and go bad.
				if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
					maxPrompt = maxPrompt[:numCtx*2]
				}

				testCases := []struct {
					prompt  string
					anyResp []string
				}{
					{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
					{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
				}
				var gpuPercent int
				for _, tc := range testCases {
					if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
						slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
						continue
					}
					req := api.GenerateRequest{
						Model:     model,
						Prompt:    tc.prompt,
						KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
						Options: map[string]interface{}{
							"temperature": 0,
							"seed":        123,
							"num_ctx":     numCtx,
						},
					}
					atLeastOne := false
					var resp api.GenerateResponse

					stream := false
					req.Stream = &stream

					// Avoid potentially getting stuck indefinitely
					limit := 5 * time.Minute
					genCtx, cancel := context.WithDeadlineCause(
						ctx,
						time.Now().Add(limit),
						fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
					)
					defer cancel()

					err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
						resp = rsp
						return nil
					})
					if err != nil {
						// Avoid excessive test runs, but don't consider a failure with massive context
						if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
							slog.Warn("max context was taking too long, skipping", "error", err)
							keepGoing = false
							skipLongPrompt = true
							continue
						}
						t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
					}
					loaded = true
					for _, expResp := range tc.anyResp {
						if strings.Contains(strings.ToLower(resp.Response), expResp) {
							atLeastOne = true
							break
						}
					}
					if !atLeastOne {
						t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s", numCtx, tc.anyResp, resp.Response)
					}
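					// Inspect the running models to see how much of this model landed in VRAM;
					// the resulting GPU percentage is reported below and drives the long-prompt skipping heuristic.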
					models, err := client.ListRunning(ctx)
					if err != nil {
						slog.Warn("failed to list running models", "error", err)
						continue
					}
					if len(models.Models) > 1 {
						slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
					}
					for _, m := range models.Models {
						if m.Name == model {
							if m.SizeVRAM == 0 {
								slog.Info("Model fully loaded into CPU")
								gpuPercent = 0
								keepGoing = false
								skipLongPrompt = true
							} else if m.SizeVRAM == m.Size {
								slog.Info("Model fully loaded into GPU")
								gpuPercent = 100
							} else {
								sizeCPU := m.Size - m.SizeVRAM
								cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
								gpuPercent = int(100 - cpuPercent)
								slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
								keepGoing = false

								// Heuristic to avoid excessive test run time
								if gpuPercent < 90 {
									skipLongPrompt = true
								}
							}
						}
					}
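					// Emit CSV header/data rows on stderr (scraped into perf.csv per the usage notes above).
					// API durations are in nanoseconds, so they are converted to seconds and tokens per second here.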
					fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
						"MODEL",
						"CONTEXT",
						"GPU PERCENT",
						"PROMPT COUNT",
						"LOAD TIME",
						"PROMPT EVAL TPS",
						"EVAL TPS",
					)
					fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
						model,
						numCtx,
						gpuPercent,
						resp.PromptEvalCount,
						float64(resp.LoadDuration)/1000000000.0,
						float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
						float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
					)
				}
			}
		})
	}
}