mirror of https://github.com/ollama/ollama.git
				
				
				
			
		
			
				
	
	
		
			175 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			175 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
| package sample
 | |
| 
 | |
| import (
 | |
| 	"encoding/json"
 | |
| 	"math"
 | |
| 	"math/rand/v2"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"testing"
 | |
| 
 | |
| 	"github.com/ollama/ollama/model"
 | |
| )
 | |
| 
 | |
| func TestWeighted(t *testing.T) {
 | |
| 	logits := []float32{-10, 3, -10, -10}
 | |
| 	sampler := NewSampler(0, 0, 0, 0, 0, nil)
 | |
| 	got, err := sampler.Sample(logits)
 | |
| 	if err != nil {
 | |
| 		t.Error(err)
 | |
| 		return
 | |
| 	}
 | |
| 	want := int32(1)
 | |
| 	if want != got {
 | |
| 		t.Errorf("index mismatch: want %d, got %d", want, got)
 | |
| 	}
 | |
| 
 | |
| 	logits = []float32{-100, -10, 0, 10}
 | |
| 	sampler = NewSampler(0, 0, 0, 0, 0, nil)
 | |
| 	got, err = sampler.Sample(logits)
 | |
| 	if err != nil {
 | |
| 		t.Error(err)
 | |
| 		return
 | |
| 	}
 | |
| 	want = int32(3) // Should pick highest probability with this r value
 | |
| 	if want != got {
 | |
| 		t.Errorf("index mismatch: want %d, got %d", want, got)
 | |
| 	}
 | |
| 
 | |
| 	// Test very high p
 | |
| 	logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
 | |
| 	// Use extremely small topP to filter out all tokens
 | |
| 	sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
 | |
| 	got, err = sampler.Sample(logits)
 | |
| 	if err != nil {
 | |
| 		t.Error(err)
 | |
| 		return
 | |
| 	}
 | |
| 	// Should get the token with the highest logit
 | |
| 	want = int32(0)
 | |
| 	if want != got {
 | |
| 		t.Errorf("index mismatch: want %d, got %d", want, got)
 | |
| 	}
 | |
| 
 | |
| 	logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
 | |
| 	sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
 | |
| 	got, err = sampler.Sample(logits)
 | |
| 	if err == nil {
 | |
| 		t.Errorf("expected error, got %d", got)
 | |
| 		return
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func modelHelper(t testing.TB) model.BytePairEncoding {
 | |
| 	t.Helper()
 | |
| 
 | |
| 	f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	defer f.Close()
 | |
| 
 | |
| 	vocab := make(map[string]int32)
 | |
| 	if err := json.NewDecoder(f).Decode(&vocab); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	tokens := make([]string, len(vocab))
 | |
| 	for token, id := range vocab {
 | |
| 		tokens[id] = token
 | |
| 	}
 | |
| 
 | |
| 	merges := make([]string, 0, 1)
 | |
| 	// Only need vocab for Grammar Test
 | |
| 	return model.NewBytePairEncoding(
 | |
| 		&model.Vocabulary{
 | |
| 			Values: tokens,
 | |
| 			Types:  make([]int32, len(vocab)),
 | |
| 			Merges: merges,
 | |
| 		},
 | |
| 	)
 | |
| }
 | |
| 
 | |
| func TestGrammar(t *testing.T) {
 | |
| 	tokenizer := modelHelper(t)
 | |
| 
 | |
| 	grammarJSON := `
 | |
| 	root   ::= object
 | |
| 	value  ::= object | array | string | number | ("true" | "false" | "null") ws
 | |
| 	object ::=
 | |
| 	"{" ws (
 | |
| 				string ":" ws value
 | |
| 		("," ws string ":" ws value)*
 | |
| 	)? "}" ws
 | |
| 	array  ::=
 | |
| 	"[" ws (
 | |
| 				value
 | |
| 		("," ws value)*
 | |
| 	)? "]" ws
 | |
| 	string ::=
 | |
| 	"\"" (
 | |
| 		[^"\\\x7F\x00-\x1F] |
 | |
| 		"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
 | |
| 	)* "\"" ws
 | |
| 	number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
 | |
| 	# Optional space: by convention, applied in this grammar after literal chars when allowed
 | |
| 	ws ::= ([ \t\n] ws)?
 | |
| 	`
 | |
| 	grammar, err := NewGrammarSampler(tokenizer, grammarJSON)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	defer grammar.Free()
 | |
| 
 | |
| 	logits := make([]float32, len(tokenizer.Vocabulary().Values))
 | |
| 	for i := range logits {
 | |
| 		logits[i] = rand.Float32()
 | |
| 	}
 | |
| 	tokens := make([]token, len(logits))
 | |
| 	for i := range tokens {
 | |
| 		tokens[i].id = int32(i)
 | |
| 		tokens[i].value = logits[i]
 | |
| 	}
 | |
| 
 | |
| 	grammar.Apply(tokens)
 | |
| 	nonInfCount := 0
 | |
| 	infCount := 0
 | |
| 	for _, tok := range tokens {
 | |
| 		if math.IsInf(float64(tok.value), -1) {
 | |
| 			infCount++
 | |
| 		} else {
 | |
| 			nonInfCount++
 | |
| 		}
 | |
| 	}
 | |
| 	if nonInfCount == 0 {
 | |
| 		t.Error("expected at least one non -inf token after grammar application, got none")
 | |
| 	}
 | |
| 	if infCount == 0 {
 | |
| 		t.Error("expected some -inf tokens after grammar application, got none")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func BenchmarkSample(b *testing.B) {
 | |
| 	samplers := map[string]Sampler{
 | |
| 		"Greedy":   NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
 | |
| 		"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
 | |
| 	}
 | |
| 
 | |
| 	// Generate random logits for benchmarking
 | |
| 	logits := make([]float32, 1<<16)
 | |
| 	for i := range logits {
 | |
| 		logits[i] = rand.Float32()
 | |
| 	}
 | |
| 
 | |
| 	for name, s := range samplers {
 | |
| 		b.Run(name, func(b *testing.B) {
 | |
| 			b.ResetTimer()
 | |
| 			for b.Loop() {
 | |
| 				if _, err := s.Sample(logits); err != nil {
 | |
| 					b.Fatalf("error sampling: %v", err)
 | |
| 				}
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 |