2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								package sample
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import (
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 12:45:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									"container/heap"
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									"math"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									"slices"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 12:45:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								type tokenHeap []token
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								func (h tokenHeap) Len() int           { return len(h) }
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 23:58:08 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 12:45:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								func (h tokenHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								func (h *tokenHeap) Push(x any) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									*h = append(*h, x.(token))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								func (h *tokenHeap) Pop() any {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									old := *h
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									n := len(old)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									x := old[n-1]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									*h = old[0 : n-1]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									return x
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-14 00:53:27 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								// temperature applies scaling to the logits
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-18 02:24:18 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								func temperature(ts []token, temp float32) {
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-14 00:53:27 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									// Ensure temperature clipping near 0 to avoid numerical instability
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									temp = max(temp, 1e-7)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									for i := range ts {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										ts[i].value = ts[i].value / temp
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								// softmax applies normalization to the logits
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-18 02:24:18 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								func softmax(ts []token) {
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-11 05:43:53 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									// Find max logit for numerical stability
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									maxLogit := float32(math.Inf(-1))
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									for _, t := range ts {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										if t.value > maxLogit {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											maxLogit = t.value
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-14 00:53:27 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									// Compute exp(x - max)
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									var sum float32
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									for i, v := range ts {
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-14 00:53:27 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										sum += ts[i].value
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-14 00:53:27 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									// exp(x - max) / sum(exp(x - max))
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									for i := range ts {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										ts[i].value /= sum
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								// topK limits the number of tokens considered to the k highest logits
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-10 23:17:39 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								func topK(ts []token, k int) []token {
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-13 01:40:25 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									if k >= len(ts) || k <= 0 {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										slices.SortFunc(ts, func(a, b token) int {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											switch {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											case a.value < b.value:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
												return 1
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											case a.value > b.value:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
												return -1
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											default:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
												return 0
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										})
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										return ts
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 12:45:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									// Initialize min-heap with first k elements
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									h := make(tokenHeap, k)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									copy(h, ts[:k])
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									heap.Init(&h)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									// Process remaining elements
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									for i := k; i < len(ts); i++ {
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 12:45:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										if ts[i].value > h[0].value {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											heap.Pop(&h)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											heap.Push(&h, ts[i])
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
										}
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 12:45:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									// Convert heap to sorted slice in descending order
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 23:58:08 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									result := make([]token, len(h))
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 12:45:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									for i := k - 1; i >= 0; i-- {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										result[i] = heap.Pop(&h).(token)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-12 12:45:41 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									return result
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								// topP limits tokens to those with cumulative probability p
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-18 02:24:18 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								// requires ts to be sorted in descending order of probabilities
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-10 23:17:39 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								func topP(ts []token, p float32) []token {
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									if p == 1.0 {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										return ts
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									// Find cutoff index where cumulative sum exceeds p
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									var sum float32
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									for i, t := range ts {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										sum += t.value
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										if sum > float32(p) {
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-18 02:24:18 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
											return ts[:i+1]
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
										}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									return ts
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-18 02:24:18 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								// minP filters tokens with probabilities >= p * max_prob
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								// requires ts to be sorted in descending order of probabilities
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-10 23:17:39 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								func minP(ts []token, p float32) []token {
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-18 02:24:18 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									maxProb := ts[0].value
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-18 02:24:18 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									threshold := maxProb * p
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-18 02:24:18 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									for i, t := range ts {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
										if t.value < threshold {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
											return ts[:i]
							 | 
						
					
						
							
								
									
										
										
										
											2025-02-25 09:19:01 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
										}
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
									}
							 | 
						
					
						
							
								
									
										
										
										
											2025-03-08 04:37:48 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
									return ts
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 |