| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | package sample | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2025-04-25 02:51:19 +08:00
										 |  |  | 	"encoding/json" | 
					
						
							| 
									
										
										
										
											2025-03-21 02:11:18 +08:00
										 |  |  | 	"math" | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 	"math/rand/v2" | 
					
						
							| 
									
										
										
										
											2025-04-25 02:51:19 +08:00
										 |  |  | 	"os" | 
					
						
							|  |  |  | 	"path/filepath" | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 	"testing" | 
					
						
							| 
									
										
										
										
											2025-04-25 02:51:19 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/ollama/ollama/model" | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestWeighted(t *testing.T) { | 
					
						
							| 
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 |  |  | 	logits := []float32{-10, 3, -10, -10} | 
					
						
							| 
									
										
										
										
											2025-03-10 23:17:39 +08:00
										 |  |  | 	sampler := NewSampler(0, 0, 0, 0, 0, nil) | 
					
						
							| 
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 |  |  | 	got, err := sampler.Sample(logits) | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Error(err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	want := int32(1) | 
					
						
							|  |  |  | 	if want != got { | 
					
						
							|  |  |  | 		t.Errorf("index mismatch: want %d, got %d", want, got) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 |  |  | 	logits = []float32{-100, -10, 0, 10} | 
					
						
							| 
									
										
										
										
											2025-03-10 23:17:39 +08:00
										 |  |  | 	sampler = NewSampler(0, 0, 0, 0, 0, nil) | 
					
						
							| 
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 |  |  | 	got, err = sampler.Sample(logits) | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Error(err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 |  |  | 	want = int32(3) // Should pick highest probability with this r value
 | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 	if want != got { | 
					
						
							|  |  |  | 		t.Errorf("index mismatch: want %d, got %d", want, got) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2025-03-21 02:11:18 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Test very high p
 | 
					
						
							|  |  |  | 	logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1} | 
					
						
							|  |  |  | 	// Use extremely small topP to filter out all tokens
 | 
					
						
							|  |  |  | 	sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil) | 
					
						
							|  |  |  | 	got, err = sampler.Sample(logits) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Error(err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	// Should get the token with the highest logit
 | 
					
						
							|  |  |  | 	want = int32(0) | 
					
						
							|  |  |  | 	if want != got { | 
					
						
							|  |  |  | 		t.Errorf("index mismatch: want %d, got %d", want, got) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())} | 
					
						
							|  |  |  | 	sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil) | 
					
						
							|  |  |  | 	got, err = sampler.Sample(logits) | 
					
						
							|  |  |  | 	if err == nil { | 
					
						
							|  |  |  | 		t.Errorf("expected error, got %d", got) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-25 02:51:19 +08:00
										 |  |  | func modelHelper(t testing.TB) model.BytePairEncoding { | 
					
						
							|  |  |  | 	t.Helper() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json")) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer f.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	vocab := make(map[string]int32) | 
					
						
							|  |  |  | 	if err := json.NewDecoder(f).Decode(&vocab); err != nil { | 
					
						
							|  |  |  | 		t.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	tokens := make([]string, len(vocab)) | 
					
						
							|  |  |  | 	for token, id := range vocab { | 
					
						
							|  |  |  | 		tokens[id] = token | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	merges := make([]string, 0, 1) | 
					
						
							|  |  |  | 	// Only need vocab for Grammar Test
 | 
					
						
							|  |  |  | 	return model.NewBytePairEncoding( | 
					
						
							|  |  |  | 		&model.Vocabulary{ | 
					
						
							|  |  |  | 			Values: tokens, | 
					
						
							| 
									
										
										
										
											2025-04-26 07:15:08 +08:00
										 |  |  | 			Types:  make([]int32, len(vocab)), | 
					
						
							| 
									
										
										
										
											2025-04-25 02:51:19 +08:00
										 |  |  | 			Merges: merges, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestGrammar(t *testing.T) { | 
					
						
							|  |  |  | 	tokenizer := modelHelper(t) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	grammarJSON := ` | 
					
						
							|  |  |  | 	root   ::= object | 
					
						
							|  |  |  | 	value  ::= object | array | string | number | ("true" | "false" | "null") ws | 
					
						
							|  |  |  | 	object ::= | 
					
						
							|  |  |  | 	"{" ws ( | 
					
						
							|  |  |  | 				string ":" ws value | 
					
						
							|  |  |  | 		("," ws string ":" ws value)* | 
					
						
							|  |  |  | 	)? "}" ws | 
					
						
							|  |  |  | 	array  ::= | 
					
						
							|  |  |  | 	"[" ws ( | 
					
						
							|  |  |  | 				value | 
					
						
							|  |  |  | 		("," ws value)* | 
					
						
							|  |  |  | 	)? "]" ws | 
					
						
							|  |  |  | 	string ::= | 
					
						
							|  |  |  | 	"\"" ( | 
					
						
							|  |  |  | 		[^"\\\x7F\x00-\x1F] | | 
					
						
							|  |  |  | 		"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes | 
					
						
							|  |  |  | 	)* "\"" ws | 
					
						
							|  |  |  | 	number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws | 
					
						
							|  |  |  | 	# Optional space: by convention, applied in this grammar after literal chars when allowed | 
					
						
							|  |  |  | 	ws ::= ([ \t\n] ws)? | 
					
						
							|  |  |  | 	` | 
					
						
							|  |  |  | 	grammar, err := NewGrammarSampler(tokenizer, grammarJSON) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer grammar.Free() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	logits := make([]float32, len(tokenizer.Vocabulary().Values)) | 
					
						
							|  |  |  | 	for i := range logits { | 
					
						
							|  |  |  | 		logits[i] = rand.Float32() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	tokens := make([]token, len(logits)) | 
					
						
							|  |  |  | 	for i := range tokens { | 
					
						
							|  |  |  | 		tokens[i].id = int32(i) | 
					
						
							|  |  |  | 		tokens[i].value = logits[i] | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	grammar.Apply(tokens) | 
					
						
							|  |  |  | 	nonInfCount := 0 | 
					
						
							|  |  |  | 	infCount := 0 | 
					
						
							|  |  |  | 	for _, tok := range tokens { | 
					
						
							|  |  |  | 		if math.IsInf(float64(tok.value), -1) { | 
					
						
							|  |  |  | 			infCount++ | 
					
						
							|  |  |  | 		} else { | 
					
						
							|  |  |  | 			nonInfCount++ | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if nonInfCount == 0 { | 
					
						
							|  |  |  | 		t.Error("expected at least one non -inf token after grammar application, got none") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if infCount == 0 { | 
					
						
							|  |  |  | 		t.Error("expected some -inf tokens after grammar application, got none") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | func BenchmarkSample(b *testing.B) { | 
					
						
							|  |  |  | 	samplers := map[string]Sampler{ | 
					
						
							| 
									
										
										
										
											2025-03-10 23:17:39 +08:00
										 |  |  | 		"Greedy":   NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
 | 
					
						
							|  |  |  | 		"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil), | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 |  |  | 	// Generate random logits for benchmarking
 | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 	logits := make([]float32, 1<<16) | 
					
						
							|  |  |  | 	for i := range logits { | 
					
						
							|  |  |  | 		logits[i] = rand.Float32() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for name, s := range samplers { | 
					
						
							|  |  |  | 		b.Run(name, func(b *testing.B) { | 
					
						
							|  |  |  | 			b.ResetTimer() | 
					
						
							| 
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 |  |  | 			for b.Loop() { | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 				if _, err := s.Sample(logits); err != nil { | 
					
						
							| 
									
										
										
										
											2025-03-10 23:17:39 +08:00
										 |  |  | 					b.Fatalf("error sampling: %v", err) | 
					
						
							| 
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } |