test: improve scheduler/concurrency stress tests (#11906)

* test: improve scheduler/concurrency stress tests

The scheduler test used to rely on approximate memory figures and would often
overshoot or undershoot a system's capacity, leading to flaky test results.
This change improves the reliability of the scenario by leveraging
ps output to determine exactly how many models it takes to
trigger thrashing.
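
In rough outline, the capacity probe works like the condensed sketch below. It is not the exact test body (the full diff follows); findLoadCapacity is a hypothetical helper name, the harness supplies client and ctx, and ListRunning is the API that backs ps:

// findLoadCapacity (hypothetical name): load models one at a time and stop as
// soon as the server reports fewer running models than we have requested,
// i.e. the scheduler has started evicting.
func findLoadCapacity(ctx context.Context, t *testing.T, client *api.Client, chosenModels []string) int {
    targetLoadCount := 0
    for i, model := range chosenModels {
        // A generate request with no prompt just loads the model.
        err := client.Generate(ctx, &api.GenerateRequest{Model: model},
            func(resp api.GenerateResponse) error { return nil })
        if err != nil {
            t.Fatalf("failed to load model %s: %s", model, err)
        }
        targetLoadCount++
        if i > 0 {
            running, err := client.ListRunning(ctx) // same data as `ollama ps`
            if err != nil {
                t.Fatalf("failed to list running models: %s", err)
            }
            if len(running.Models) < targetLoadCount {
                break // one model more than fits: the thrash point
            }
        }
    }
    return targetLoadCount
}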

The concurrency test is also refined to target num_parallel + 1 concurrent
requests and to handle timeouts better.
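
A condensed sketch of that approach is shown here; runConcurrentWorkers is a hypothetical wrapper (the real test body is in the diff below), and DoGenerate, started, and getTimeouts come from the integration harness:

// runConcurrentWorkers (hypothetical): size the worker pool off the server's
// OLLAMA_NUM_PARALLEL setting plus one, so there is always one more in-flight
// request than the server decodes in parallel, and wind down at the soft timeout.
func runConcurrentWorkers(ctx context.Context, t *testing.T, client *api.Client,
    reqs []api.GenerateRequest, resps [][]string, started time.Time, softTimeout time.Duration, iterLimit int) {
    numParallel := int(envconfig.NumParallel() + 1)
    var wg sync.WaitGroup
    wg.Add(numParallel)
    for i := 0; i < numParallel; i++ {
        go func(i int) {
            defer wg.Done()
            for j := 0; j < iterLimit; j++ {
                if time.Since(started) > softTimeout {
                    return // soft deadline hit: stop issuing work instead of failing
                }
                k := (i + j) % len(reqs)
                DoGenerate(ctx, t, client, reqs[k], resps[k], 120*time.Second, 20*time.Second)
            }
        }(i)
    }
    wg.Wait()
}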

With these refinements, TestMultiModelConcurrency became redundant and has been removed.

* test: add parallel generate with history

TestGenerateWithHistory helps verify that caching and context
are handled properly across a sequence of requests.
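
The core of that check is threading the context tokens returned by DoGenerate back into the next request, as in this condensed sketch (generateWithHistory is a hypothetical wrapper; the actual test is in the diff below):

// generateWithHistory (hypothetical): DoGenerate now returns the context of the
// last streamed response, which is fed back so the model continues the same
// conversation and exercises the prompt cache.
func generateWithHistory(ctx context.Context, t *testing.T, client *api.Client,
    req api.GenerateRequest, want []string, turns int) {
    for j := 0; j < turns; j++ {
        c := DoGenerate(ctx, t, client, req, want, 120*time.Second, 20*time.Second)
        req.Context = c              // carry the prior conversation state forward
        req.Prompt = "tell me more!" // follow-up relies on the cached context
    }
}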

* test: focus embed tests on embedding models

Remove non-embedding models from the embedding tests.
Daniel Hiltgen, 2025-08-15 14:37:54 -07:00, committed by Richard Lyons
parent 8053d20fe6
commit 1d18f2de74
4 changed files with 169 additions and 230 deletions


@@ -7,6 +7,7 @@ import (
     "fmt"
     "log/slog"
     "math"
+    "math/rand"
     "os"
     "strconv"
     "sync"
@@ -16,245 +17,157 @@ import (
     "github.com/stretchr/testify/require"
     "github.com/ollama/ollama/api"
+    "github.com/ollama/ollama/envconfig"
     "github.com/ollama/ollama/format"
 )
-func TestMultiModelConcurrency(t *testing.T) {
-    var (
-        req = [2]api.GenerateRequest{
-            {
-                Model:     smol,
-                Prompt:    "why is the ocean blue?",
-                Stream:    &stream,
-                KeepAlive: &api.Duration{Duration: 10 * time.Second},
-                Options:   map[string]any{"seed": 42, "temperature": 0.0},
-            }, {
-                Model:     "qwen3:0.6b",
-                Prompt:    "what is the origin of the us thanksgiving holiday?",
-                Stream:    &stream,
-                KeepAlive: &api.Duration{Duration: 10 * time.Second},
-                Options:   map[string]any{"seed": 42, "temperature": 0.0},
-            },
-        }
-        resp = [2][]string{
-            {"sunlight"},
-            {"england", "english", "massachusetts", "pilgrims", "british", "festival"},
-        }
-    )
-    var wg sync.WaitGroup
-    wg.Add(len(req))
-    ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
-    defer cancel()
-    client, _, cleanup := InitServerConnection(ctx, t)
-    defer cleanup()
-    for i := 0; i < len(req); i++ {
-        require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
-    }
-    for i := 0; i < len(req); i++ {
-        go func(i int) {
-            defer wg.Done()
-            // Note: CPU based inference can crawl so don't give up too quickly
-            DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 30*time.Second)
-        }(i)
-    }
-    wg.Wait()
-}
-
-func TestIntegrationConcurrentPredict(t *testing.T) {
-    req, resp := GenerateRequests()
-    reqLimit := len(req)
-    iterLimit := 5
-    if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
-        maxVram, err := strconv.ParseUint(s, 10, 64)
-        require.NoError(t, err)
-        // Don't hammer on small VRAM cards...
-        if maxVram < 4*format.GibiByte {
-            reqLimit = min(reqLimit, 2)
-            iterLimit = 2
-        }
-    }
-    ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
-    defer cancel()
-    client, _, cleanup := InitServerConnection(ctx, t)
-    defer cleanup()
-    // Get the server running (if applicable) warm the model up with a single initial request
-    DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)
-    var wg sync.WaitGroup
-    wg.Add(reqLimit)
-    for i := 0; i < reqLimit; i++ {
-        go func(i int) {
-            defer wg.Done()
-            for j := 0; j < iterLimit; j++ {
-                slog.Info("Starting", "req", i, "iter", j)
-                // On slower GPUs it can take a while to process the concurrent requests
-                // so we allow a much longer initial timeout
-                DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
-            }
-        }(i)
-    }
-    wg.Wait()
-}
+// Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected
+func TestConcurrentGenerate(t *testing.T) {
+    // Assumes all requests have the same model
+    req, resp := GenerateRequests()
+    numParallel := int(envconfig.NumParallel() + 1)
+    iterLimit := 3
+    softTimeout, hardTimeout := getTimeouts(t)
+    ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+    // Get the server running (if applicable) warm the model up with a single initial request
+    slog.Info("loading", "model", req[0].Model)
+    err := client.Generate(ctx,
+        &api.GenerateRequest{Model: req[0].Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
+        func(response api.GenerateResponse) error { return nil },
+    )
+    if err != nil {
+        t.Fatalf("failed to load model %s: %s", req[0].Model, err)
+    }
+    var wg sync.WaitGroup
+    r := rand.New(rand.NewSource(0))
+    wg.Add(numParallel)
+    for i := range numParallel {
+        go func(i int) {
+            defer wg.Done()
+            for j := 0; j < iterLimit; j++ {
+                if time.Now().Sub(started) > softTimeout {
+                    slog.Info("exceeded soft timeout, winding down test")
+                    return
+                }
+                k := r.Int() % len(req)
+                slog.Info("Starting", "thread", i, "iter", j)
+                // On slower GPUs it can take a while to process the concurrent requests
+                // so we allow a much longer initial timeout
+                DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
+            }
+        }(i)
+    }
+    wg.Wait()
+}
-// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
-func TestMultiModelStress(t *testing.T) {
-    s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
-    if s == "" {
-        t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
-    }
-    maxVram, err := strconv.ParseUint(s, 10, 64)
-    if err != nil {
-        t.Fatal(err)
-    }
-    if maxVram < 2*format.GibiByte {
-        t.Skip("VRAM less than 2G, skipping model stress tests")
-    }
-    type model struct {
-        name string
-        size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
-    }
-    smallModels := []model{
-        {name: "llama3.2:1b", size: 2876 * format.MebiByte},
-        {name: "qwen3:0.6b", size: 1600 * format.MebiByte},
-        {name: "gemma:2b", size: 2364 * format.MebiByte},
-        {name: "deepseek-r1:1.5b", size: 2048 * format.MebiByte},
-        {name: "starcoder2:3b", size: 2166 * format.MebiByte},
-    }
-    mediumModels := []model{
-        {name: "qwen3:8b", size: 6600 * format.MebiByte},
-        {name: "llama2", size: 5118 * format.MebiByte},
-        {name: "deepseek-r1:7b", size: 5600 * format.MebiByte},
-        {name: "mistral", size: 4620 * format.MebiByte},
-        {name: "dolphin-mistral", size: 4620 * format.MebiByte},
-        {name: "gemma:7b", size: 5000 * format.MebiByte},
-        {name: "codellama:7b", size: 5118 * format.MebiByte},
-    }
-    // These seem to be too slow to be useful...
-    // largeModels := []model{
-    //     {name: "llama2:13b", size: 7400 * format.MebiByte},
-    //     {name: "codellama:13b", size: 7400 * format.MebiByte},
-    //     {name: "orca-mini:13b", size: 7400 * format.MebiByte},
-    //     {name: "gemma:7b", size: 5000 * format.MebiByte},
-    //     {name: "starcoder2:15b", size: 9100 * format.MebiByte},
-    // }
-    var chosenModels []model
-    switch {
-    case maxVram < 10000*format.MebiByte:
-        slog.Info("selecting small models")
-        chosenModels = smallModels
-    // case maxVram < 30000*format.MebiByte:
-    default:
-        slog.Info("selecting medium models")
-        chosenModels = mediumModels
-        // default:
-        //     slog.Info("selecting large models")
-        //     chosenModels = largeModels
-    }
-    req, resp := GenerateRequests()
-    for i := range req {
-        if i > len(chosenModels) {
-            break
-        }
-        req[i].Model = chosenModels[i].name
-    }
-    ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
-    defer cancel()
-    client, _, cleanup := InitServerConnection(ctx, t)
-    defer cleanup()
-    // Make sure all the models are pulled before we get started
-    for _, r := range req {
-        require.NoError(t, PullIfMissing(ctx, client, r.Model))
-    }
-    var wg sync.WaitGroup
-    consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
-    for i := 0; i < len(req); i++ {
-        // Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
-        if i > 1 && consumed > maxVram {
-            slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
-            break
-        }
-        consumed += chosenModels[i].size
-        slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
-        wg.Add(1)
-        go func(i int) {
-            defer wg.Done()
-            for j := 0; j < 3; j++ {
-                slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
-                DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
-            }
-        }(i)
-    }
+// Stress the scheduler and attempt to load more models than will fit to cause thrashing
+// This test will always load at least 2 models even on CPU based systems
+func TestMultiModelStress(t *testing.T) {
+    s := os.Getenv("OLLAMA_MAX_VRAM")
+    if s == "" {
+        s = "0"
+    }
+    maxVram, err := strconv.ParseUint(s, 10, 64)
+    if err != nil {
+        t.Fatal(err)
+    }
+    smallModels := []string{
+        "llama3.2:1b",
+        "qwen3:0.6b",
+        "gemma:2b",
+        "deepseek-r1:1.5b",
+        "starcoder2:3b",
+    }
+    mediumModels := []string{
+        "qwen3:8b",
+        "llama2",
+        "deepseek-r1:7b",
+        "mistral",
+        "dolphin-mistral",
+        "gemma:7b",
+        "codellama:7b",
+    }
+    var chosenModels []string
+    switch {
+    case maxVram < 10000*format.MebiByte:
+        slog.Info("selecting small models")
+        chosenModels = smallModels
+    default:
+        slog.Info("selecting medium models")
+        chosenModels = mediumModels
+    }
+    softTimeout, hardTimeout := getTimeouts(t)
+    ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+    // Make sure all the models are pulled before we get started
+    for _, model := range chosenModels {
+        require.NoError(t, PullIfMissing(ctx, client, model))
+    }
+    // Determine how many models we can load in parallel before we exceed VRAM
+    // The intent is to go 1 over what can fit so we force the scheduler to thrash
+    targetLoadCount := 0
+    slog.Info("Loading models to find how many can fit in VRAM before overflowing")
+    for i, model := range chosenModels {
+        req := &api.GenerateRequest{Model: model}
+        slog.Info("loading", "model", model)
+        err = client.Generate(ctx, req, func(response api.GenerateResponse) error { return nil })
+        if err != nil {
+            t.Fatalf("failed to load model %s: %s", model, err)
+        }
+        targetLoadCount++
+        if i > 0 {
+            models, err := client.ListRunning(ctx)
+            if err != nil {
+                t.Fatalf("failed to list running models: %s", err)
+            }
+            if len(models.Models) < targetLoadCount {
+                loaded := []string{}
+                for _, m := range models.Models {
+                    loaded = append(loaded, m.Name)
+                }
+                slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
+                break
+            }
+        }
+    }
+    if targetLoadCount == len(chosenModels) {
+        // TODO consider retrying the medium models
+        slog.Warn("all models being used without exceeding VRAM, set OLLAMA_MAX_VRAM so test can pick larger models")
+    }
+    r := rand.New(rand.NewSource(0))
+    var wg sync.WaitGroup
+    for i := range targetLoadCount {
+        wg.Add(1)
+        go func(i int) {
+            defer wg.Done()
+            reqs, resps := GenerateRequests()
+            for j := 0; j < 3; j++ {
+                if time.Now().Sub(started) > softTimeout {
+                    slog.Info("exceeded soft timeout, winding down test")
+                    return
+                }
+                k := r.Int() % len(reqs)
+                reqs[k].Model = chosenModels[i]
+                slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
+                DoGenerate(ctx, t, client, reqs[k], resps[k],
+                    120*time.Second, // Be extra patient for the model to load initially
+                    10*time.Second,  // Once results start streaming, fail if they stall
+                )
+            }
+        }(i)
+    }


@@ -4,6 +4,8 @@ package integration
 import (
     "context"
+    "log/slog"
+    "sync"
     "testing"
     "time"
@@ -63,3 +65,51 @@ func TestContextExhaustion(t *testing.T) {
     }
     DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
 }
+
+// Send multiple requests with prior context and ensure the response is coherent and expected
+func TestGenerateWithHistory(t *testing.T) {
+    modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
+    req, resp := GenerateRequests()
+    numParallel := 2
+    iterLimit := 2
+    softTimeout, hardTimeout := getTimeouts(t)
+    ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+    // Get the server running (if applicable) warm the model up with a single initial request
+    slog.Info("loading", "model", modelOverride)
+    err := client.Generate(ctx,
+        &api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
+        func(response api.GenerateResponse) error { return nil },
+    )
+    if err != nil {
+        t.Fatalf("failed to load model %s: %s", modelOverride, err)
+    }
+    var wg sync.WaitGroup
+    wg.Add(numParallel)
+    for i := range numParallel {
+        go func(i int) {
+            defer wg.Done()
+            k := i % len(req)
+            req[k].Model = modelOverride
+            for j := 0; j < iterLimit; j++ {
+                if time.Now().Sub(started) > softTimeout {
+                    slog.Info("exceeded soft timeout, winding down test")
+                    return
+                }
+                slog.Info("Starting", "thread", i, "iter", j)
+                // On slower GPUs it can take a while to process the concurrent requests
+                // so we allow a much longer initial timeout
+                c := DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
+                req[k].Context = c
+                req[k].Prompt = "tell me more!"
+            }
+        }(i)
+    }
+    wg.Wait()
+}

File diff suppressed because one or more lines are too long


@@ -472,15 +472,19 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
     DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
 }
 
-func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
+func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int {
     stallTimer := time.NewTimer(initialTimeout)
     var buf bytes.Buffer
+    var context []int
     fn := func(response api.GenerateResponse) error {
         // fmt.Print(".")
         buf.Write([]byte(response.Response))
         if !stallTimer.Reset(streamTimeout) {
             return errors.New("stall was detected while streaming response, aborting")
         }
+        if len(response.Context) > 0 {
+            context = response.Context
+        }
         return nil
     }

@@ -503,7 +507,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
     case <-done:
         if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
             slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
-            return
+            return context
         }
         require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
         // Verify the response contains the expected data

@@ -520,6 +524,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
     case <-ctx.Done():
         t.Error("outer test context done while waiting for generate")
     }
+    return context
 }
// Generate a set of requests // Generate a set of requests
@@ -528,55 +533,35 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
     return []api.GenerateRequest{
         {
             Model:     smol,
-            Prompt:    "why is the ocean blue?",
+            Prompt:    "why is the ocean blue? Be brief but factual in your reply",
             Stream:    &stream,
             KeepAlive: &api.Duration{Duration: 10 * time.Second},
-            Options:   map[string]any{"seed": 42, "temperature": 0.0},
         }, {
             Model:     smol,
-            Prompt:    "why is the color of dirt brown?",
+            Prompt:    "why is the color of dirt brown? Be brief but factual in your reply",
             Stream:    &stream,
             KeepAlive: &api.Duration{Duration: 10 * time.Second},
-            Options:   map[string]any{"seed": 42, "temperature": 0.0},
        }, {
             Model:     smol,
-            Prompt:    "what is the origin of the us thanksgiving holiday?",
+            Prompt:    "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
             Stream:    &stream,
             KeepAlive: &api.Duration{Duration: 10 * time.Second},
-            Options:   map[string]any{"seed": 42, "temperature": 0.0},
        }, {
             Model:     smol,
-            Prompt:    "what is the origin of independence day?",
+            Prompt:    "what is the origin of independence day? Be brief but factual in your reply",
             Stream:    &stream,
             KeepAlive: &api.Duration{Duration: 10 * time.Second},
-            Options:   map[string]any{"seed": 42, "temperature": 0.0},
        }, {
             Model:     smol,
-            Prompt:    "what is the composition of air?",
+            Prompt:    "what is the composition of air? Be brief but factual in your reply",
             Stream:    &stream,
             KeepAlive: &api.Duration{Duration: 10 * time.Second},
-            Options:   map[string]any{"seed": 42, "temperature": 0.0},
         },
     },
     [][]string{
-        {"sunlight", "scattering", "interact"},
-        {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles"},
-        {"england", "english", "massachusetts", "pilgrims", "british"},
+        {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
+        {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
+        {"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"},
         {"fourth", "july", "declaration", "independence"},
         {"nitrogen", "oxygen", "carbon", "dioxide"},
     }