| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | package model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"cmp" | 
					
						
							| 
									
										
										
										
											2025-05-13 02:43:00 +08:00
										 |  |  | 	"context" | 
					
						
							| 
									
										
										
										
											2025-05-29 10:38:52 +08:00
										 |  |  | 	"fmt" | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	"iter" | 
					
						
							|  |  |  | 	"log/slog" | 
					
						
							|  |  |  | 	"strings" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/dlclark/regexp2" | 
					
						
							|  |  |  | 	heap "github.com/emirpasic/gods/v2/trees/binaryheap" | 
					
						
							| 
									
										
										
										
											2025-05-13 02:43:00 +08:00
										 |  |  | 	"github.com/ollama/ollama/logutil" | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type BytePairEncoding struct { | 
					
						
							|  |  |  | 	pre   *regexp2.Regexp | 
					
						
							|  |  |  | 	vocab *Vocabulary | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-25 02:51:19 +08:00
										 |  |  | var _ TextProcessor = (*BytePairEncoding)(nil) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { | 
					
						
							|  |  |  | 	return BytePairEncoding{ | 
					
						
							| 
									
										
										
										
											2025-08-06 03:21:16 +08:00
										 |  |  | 		pre:   regexp2.MustCompile(pre, regexp2.None), | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 		vocab: vocab, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-25 02:51:19 +08:00
										 |  |  | func (bpe BytePairEncoding) Vocabulary() *Vocabulary { | 
					
						
							|  |  |  | 	return bpe.vocab | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-04 11:12:04 +08:00
										 |  |  | func (bpe BytePairEncoding) Is(id int32, special Special) bool { | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	return bpe.vocab.Is(id, special) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { | 
					
						
							|  |  |  | 	return func(yield func(string) bool) { | 
					
						
							|  |  |  | 		for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) { | 
					
						
							|  |  |  | 			if !yield(m.String()) { | 
					
						
							|  |  |  | 				break | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // fragment is a string fragment and their corresponding token IDs
 | 
					
						
							|  |  |  | type fragment struct { | 
					
						
							|  |  |  | 	value string | 
					
						
							|  |  |  | 	ids   []int32 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // pair is a pair of runes and its rank
 | 
					
						
							|  |  |  | type pair struct { | 
					
						
							|  |  |  | 	a, b  int | 
					
						
							|  |  |  | 	rank  int | 
					
						
							|  |  |  | 	value string | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type merge struct { | 
					
						
							|  |  |  | 	p, n  int | 
					
						
							|  |  |  | 	runes []rune | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-06 05:27:53 +08:00
										 |  |  | func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	fragments := []fragment{{value: s}} | 
					
						
							|  |  |  | 	for _, special := range bpe.vocab.SpecialVocabulary() { | 
					
						
							|  |  |  | 		// TODO: process special tokens concurrently
 | 
					
						
							|  |  |  | 		id := bpe.vocab.Encode(special) | 
					
						
							|  |  |  | 		for i := 0; i < len(fragments); i++ { | 
					
						
							|  |  |  | 			frag := fragments[i] | 
					
						
							|  |  |  | 			if len(frag.ids) > 0 { | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			var middle []fragment | 
					
						
							|  |  |  | 			switch i := strings.Index(frag.value, special); { | 
					
						
							|  |  |  | 			case i < 0: | 
					
						
							|  |  |  | 				middle = append(middle, frag) | 
					
						
							|  |  |  | 			case i > 0: | 
					
						
							|  |  |  | 				middle = append(middle, fragment{value: frag.value[:i]}) | 
					
						
							|  |  |  | 				fallthrough | 
					
						
							|  |  |  | 			default: | 
					
						
							|  |  |  | 				middle = append(middle, fragment{value: special, ids: []int32{id}}) | 
					
						
							|  |  |  | 				if rest := frag.value[i+len(special):]; rest != "" { | 
					
						
							|  |  |  | 					middle = append(middle, fragment{value: rest}) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var ids []int32 | 
					
						
							|  |  |  | 	for _, frag := range fragments { | 
					
						
							|  |  |  | 		if len(frag.ids) > 0 { | 
					
						
							|  |  |  | 			ids = append(ids, frag.ids...) | 
					
						
							|  |  |  | 			continue | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		for split := range bpe.split(frag.value) { | 
					
						
							|  |  |  | 			// TODO: process splits concurrently
 | 
					
						
							|  |  |  | 			var sb strings.Builder | 
					
						
							|  |  |  | 			for _, b := range []byte(split) { | 
					
						
							|  |  |  | 				r := rune(b) | 
					
						
							|  |  |  | 				switch { | 
					
						
							|  |  |  | 				case r == 0x00ad: | 
					
						
							|  |  |  | 					r = 0x0143 | 
					
						
							|  |  |  | 				case r <= 0x0020: | 
					
						
							|  |  |  | 					r = r + 0x0100 | 
					
						
							|  |  |  | 				case r >= 0x007e && r <= 0x00a0: | 
					
						
							|  |  |  | 					r = r + 0x00a2 | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				sb.WriteRune(r) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// short circuit if the fragment is in the vocabulary
 | 
					
						
							|  |  |  | 			if id := bpe.vocab.Encode(sb.String()); id >= 0 { | 
					
						
							|  |  |  | 				ids = append(ids, id) | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			runes := []rune(sb.String()) | 
					
						
							|  |  |  | 			merges := make([]merge, len(runes)) | 
					
						
							|  |  |  | 			for r := range runes { | 
					
						
							|  |  |  | 				merges[r] = merge{ | 
					
						
							|  |  |  | 					p:     r - 1, | 
					
						
							|  |  |  | 					n:     r + 1, | 
					
						
							|  |  |  | 					runes: []rune{runes[r]}, | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			pairwise := func(a, b int) *pair { | 
					
						
							|  |  |  | 				if a < 0 || b >= len(runes) { | 
					
						
							|  |  |  | 					return nil | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				left, right := string(merges[a].runes), string(merges[b].runes) | 
					
						
							|  |  |  | 				rank := bpe.vocab.Merge(left, right) | 
					
						
							|  |  |  | 				if rank < 0 { | 
					
						
							|  |  |  | 					return nil | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				return &pair{ | 
					
						
							|  |  |  | 					a:     a, | 
					
						
							|  |  |  | 					b:     b, | 
					
						
							|  |  |  | 					rank:  rank, | 
					
						
							|  |  |  | 					value: left + right, | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			pairs := heap.NewWith(func(i, j *pair) int { | 
					
						
							|  |  |  | 				return cmp.Compare(i.rank, j.rank) | 
					
						
							|  |  |  | 			}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			for i := range len(runes) - 1 { | 
					
						
							|  |  |  | 				if pair := pairwise(i, i+1); pair != nil { | 
					
						
							|  |  |  | 					pairs.Push(pair) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			for !pairs.Empty() { | 
					
						
							|  |  |  | 				pair, _ := pairs.Pop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				left, right := merges[pair.a], merges[pair.b] | 
					
						
							|  |  |  | 				if len(left.runes) == 0 || len(right.runes) == 0 || | 
					
						
							|  |  |  | 					string(left.runes)+string(right.runes) != pair.value { | 
					
						
							|  |  |  | 					continue | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-03-15 07:56:32 +08:00
										 |  |  | 				if id := bpe.vocab.Encode(pair.value); id < 0 { | 
					
						
							|  |  |  | 					continue | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 				merges[pair.a].runes = append(left.runes, right.runes...) | 
					
						
							|  |  |  | 				merges[pair.b].runes = nil | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				merges[pair.a].n = right.n | 
					
						
							|  |  |  | 				if right.n < len(merges) { | 
					
						
							|  |  |  | 					merges[right.n].p = pair.a | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { | 
					
						
							|  |  |  | 					pairs.Push(pair) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { | 
					
						
							|  |  |  | 					pairs.Push(pair) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			for _, merge := range merges { | 
					
						
							|  |  |  | 				if len(merge.runes) > 0 { | 
					
						
							|  |  |  | 					// TODO: handle the edge case where the rune isn't in the vocabulary
 | 
					
						
							|  |  |  | 					if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 { | 
					
						
							|  |  |  | 						ids = append(ids, id) | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-17 04:40:23 +08:00
										 |  |  | 	slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) | 
					
						
							| 
									
										
										
										
											2025-02-27 03:20:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-17 04:40:23 +08:00
										 |  |  | 	if addSpecial && len(ids) > 0 { | 
					
						
							|  |  |  | 		ids = bpe.vocab.addSpecials(ids) | 
					
						
							| 
									
										
										
										
											2025-02-27 03:20:51 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	return ids, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-29 10:38:52 +08:00
										 |  |  | type lazyIdsString struct { | 
					
						
							|  |  |  | 	ids []int32 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (l lazyIdsString) LogValue() slog.Value { | 
					
						
							|  |  |  | 	return slog.AnyValue(fmt.Sprint(l.ids)) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { | 
					
						
							|  |  |  | 	var sb strings.Builder | 
					
						
							|  |  |  | 	for _, id := range ids { | 
					
						
							|  |  |  | 		for _, r := range bpe.vocab.Decode(id) { | 
					
						
							|  |  |  | 			switch { | 
					
						
							|  |  |  | 			case r == 0x0100: | 
					
						
							|  |  |  | 				// this produces 0x00 aka NULL
 | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			case r == 0x0143: | 
					
						
							|  |  |  | 				r = 0x00ad | 
					
						
							|  |  |  | 			case r > 0x0100 && r <= 0x0120: | 
					
						
							|  |  |  | 				r = r - 0x0100 | 
					
						
							|  |  |  | 			case r > 0x0120 && r <= 0x0142: | 
					
						
							|  |  |  | 				r = r - 0x00a2 | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// NOTE: not using WriteRune here because it writes the UTF-8
 | 
					
						
							|  |  |  | 			// encoding of the rune which is _not_ what we want
 | 
					
						
							|  |  |  | 			if err := sb.WriteByte(byte(r)); err != nil { | 
					
						
							|  |  |  | 				return "", err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-29 10:38:52 +08:00
										 |  |  | 	slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids}) | 
					
						
							| 
									
										
										
										
											2025-02-14 08:31:21 +08:00
										 |  |  | 	return sb.String(), nil | 
					
						
							|  |  |  | } |