mirror of https://github.com/ollama/ollama.git
				
				
				
			
		
			
				
	
	
		
			131 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			131 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
| //go:build integration && models
 | |
| 
 | |
| package integration
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"log/slog"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/ollama/ollama/api"
 | |
| )
 | |
| 
 | |
| func TestQuantization(t *testing.T) {
 | |
| 	sourceModels := []string{
 | |
| 		"qwen2.5:0.5b-instruct-fp16",
 | |
| 	}
 | |
| 	quantizations := []string{
 | |
| 		"Q8_0",
 | |
| 		"Q4_K_S",
 | |
| 		"Q4_K_M",
 | |
| 		"Q4_K",
 | |
| 	}
 | |
| 	softTimeout, hardTimeout := getTimeouts(t)
 | |
| 	started := time.Now()
 | |
| 	slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
 | |
| 	defer cancel()
 | |
| 	client, _, cleanup := InitServerConnection(ctx, t)
 | |
| 	defer cleanup()
 | |
| 
 | |
| 	for _, base := range sourceModels {
 | |
| 		if err := PullIfMissing(ctx, client, base); err != nil {
 | |
| 			t.Fatalf("pull failed %s", err)
 | |
| 		}
 | |
| 		for _, quant := range quantizations {
 | |
| 			newName := fmt.Sprintf("%s__%s", base, quant)
 | |
| 			t.Run(newName, func(t *testing.T) {
 | |
| 				if time.Now().Sub(started) > softTimeout {
 | |
| 					t.Skip("skipping remaining tests to avoid excessive runtime")
 | |
| 				}
 | |
| 				req := &api.CreateRequest{
 | |
| 					Model:        newName,
 | |
| 					Quantization: quant,
 | |
| 					From:         base,
 | |
| 				}
 | |
| 				fn := func(resp api.ProgressResponse) error {
 | |
| 					// fmt.Print(".")
 | |
| 					return nil
 | |
| 				}
 | |
| 				t.Logf("quantizing: %s -> %s", base, quant)
 | |
| 				if err := client.Create(ctx, req, fn); err != nil {
 | |
| 					t.Fatalf("create failed %s", err)
 | |
| 				}
 | |
| 				defer func() {
 | |
| 					req := &api.DeleteRequest{
 | |
| 						Model: newName,
 | |
| 					}
 | |
| 					t.Logf("deleting: %s -> %s", base, quant)
 | |
| 					if err := client.Delete(ctx, req); err != nil {
 | |
| 						t.Logf("failed to clean up %s: %s", req.Model, err)
 | |
| 					}
 | |
| 				}()
 | |
| 				// Check metadata on the model
 | |
| 				resp, err := client.Show(ctx, &api.ShowRequest{Name: newName})
 | |
| 				if err != nil {
 | |
| 					t.Fatalf("unable to show model: %s", err)
 | |
| 				}
 | |
| 				if !strings.Contains(resp.Details.QuantizationLevel, quant) {
 | |
| 					t.Fatalf("unexpected quantization for %s:\ngot: %s", newName, resp.Details.QuantizationLevel)
 | |
| 				}
 | |
| 
 | |
| 				stream := true
 | |
| 				genReq := api.GenerateRequest{
 | |
| 					Model:     newName,
 | |
| 					Prompt:    "why is the sky blue?",
 | |
| 					KeepAlive: &api.Duration{Duration: 3 * time.Second},
 | |
| 					Options: map[string]any{
 | |
| 						"seed":        42,
 | |
| 						"temperature": 0.0,
 | |
| 					},
 | |
| 					Stream: &stream,
 | |
| 				}
 | |
| 				t.Logf("verifying: %s -> %s", base, quant)
 | |
| 
 | |
| 				// Some smaller quantizations can cause models to have poor quality
 | |
| 				// or get stuck in repetition loops, so we stop as soon as we have any matches
 | |
| 				anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
 | |
| 				reqCtx, reqCancel := context.WithCancel(ctx)
 | |
| 				atLeastOne := false
 | |
| 				var buf bytes.Buffer
 | |
| 				genfn := func(response api.GenerateResponse) error {
 | |
| 					buf.Write([]byte(response.Response))
 | |
| 					fullResp := strings.ToLower(buf.String())
 | |
| 					for _, resp := range anyResp {
 | |
| 						if strings.Contains(fullResp, resp) {
 | |
| 							atLeastOne = true
 | |
| 							t.Log(fullResp)
 | |
| 							reqCancel()
 | |
| 							break
 | |
| 						}
 | |
| 					}
 | |
| 					return nil
 | |
| 				}
 | |
| 
 | |
| 				done := make(chan int)
 | |
| 				var genErr error
 | |
| 				go func() {
 | |
| 					genErr = client.Generate(reqCtx, &genReq, genfn)
 | |
| 					done <- 0
 | |
| 				}()
 | |
| 
 | |
| 				select {
 | |
| 				case <-done:
 | |
| 					if genErr != nil && !atLeastOne {
 | |
| 						t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
 | |
| 					}
 | |
| 				case <-ctx.Done():
 | |
| 					t.Error("outer test context done while waiting for generate")
 | |
| 				}
 | |
| 
 | |
| 				t.Logf("passed")
 | |
| 
 | |
| 			})
 | |
| 		}
 | |
| 	}
 | |
| }
 |