model: implement bert in ollama engine (#9080)

* fix truncate

* s/SentencePieceModel/SentencePiece/

* bert

* wordpiece

* refactor pooling

* more tokenizers

* normalize embeddings
This commit is contained in:
Michael Yang 2025-09-15 15:35:59 -07:00 committed by GitHub
parent 6f7117145f
commit 3f6642f6fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 490 additions and 40 deletions

View File

@ -28,6 +28,7 @@ type bertModel struct {
LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"`
normalizeEmbeddings bool
PoolingType uint32 PoolingType uint32
} }
@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
var pooling string var pooling string
for _, m := range modules { for _, m := range modules {
if m.Type == "sentence_transformers.models.Pooling" { switch m.Type {
case "sentence_transformers.models.Pooling":
pooling = m.Path pooling = m.Path
break case "sentence_transformers.models.Normalize":
p.normalizeEmbeddings = true
} }
} }
@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv["general.architecture"] = "bert" kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType kv["bert.pooling_type"] = p.PoolingType
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)

View File

@ -416,6 +416,7 @@ type Tensor interface {
AddID(ctx Context, t2, ids Tensor) Tensor AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor Scale(ctx Context, s float64) Tensor

View File

@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
} }
} }
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)) tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil { if w != nil {

36
ml/nn/pooling/pooling.go Normal file
View File

@ -0,0 +1,36 @@
package pooling
import (
"github.com/ollama/ollama/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
TypeRank
TypeUnknown = 0xFFFFFFFE
TypeUnspecified = 0xFFFFFFFF
)
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
switch poolingType {
case TypeNone:
return hiddenStates
case TypeMean:
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
case TypeLast:
panic("not implemented")
case TypeRank:
panic("not implemented")
default:
panic("not implemented")
}
}

View File

@ -24,7 +24,11 @@ import (
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
var ErrNoVisionModel = errors.New("this model is missing data required for image input") var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface { type Model interface {
@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
vv = vv.Elem() vv = vv.Elem()
} }
vv = vv.Elem() vv = reflect.Indirect(vv)
if v.IsNil() { if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem() vv = reflect.New(v.Type().Elem()).Elem()
} }

181
model/models/bert/model.go Normal file
View File

@ -0,0 +1,181 @@
package bert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
Layers []EncoderLayer `gguf:"blk"`
Options
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
}
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}
return hiddenStates, nil
}
type EncoderLayer struct {
*Attention
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
*MLP
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
// Attention
residual := hiddenStates
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// MLP
residual = hiddenStates
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := a.Query.Forward(ctx, hiddenStates)
if a.QueryNorm != nil {
query = a.QueryNorm.Forward(ctx, query, opts.eps)
}
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
key := a.Key.Forward(ctx, hiddenStates)
if a.KeyNorm != nil {
key = a.KeyNorm.Forward(ctx, key, opts.eps)
}
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
value := a.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
}
type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength int
poolingType pooling.Type
eps float32
normalize bool
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func New(c fs.Config) (model.Model, error) {
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
//nolint:misspell
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
// support it for compatibility.
c.Uint("tokenizer.ggml.seperator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_epsilon"),
poolingType: pooling.Type(c.Uint("pooling_type")),
normalize: c.Bool("normalize_embeddings", true),
},
}, nil
}
func init() {
model.Register("bert", New)
model.Register("bert_embed", New)
}

View File

@ -24,7 +24,7 @@ type Options struct {
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"` TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"` Layers []Layer `gguf:"blk"`
@ -40,7 +40,7 @@ const (
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),

View File

@ -1,48 +1,38 @@
package gemma3 package gemma3
import ( import (
"errors"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
type embedModel struct { type embedModel struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*TextModel *TextModel
PoolingType uint32 poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"` Dense [2]*nn.Linear `gguf:"dense"`
} }
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
switch m.PoolingType {
case 0: // None
case 1: // Mean
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
default:
return nil, errors.New("unsupported pooling type")
}
for _, dense := range m.Dense { for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates) hiddenStates = dense.Forward(ctx, hiddenStates)
} }
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil return hiddenStates, nil
} }
func newEmbedModel(c fs.Config) (model.Model, error) { func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{ m := &embedModel{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
@ -60,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
}, },
), ),
TextModel: newTextModel(c), TextModel: newTextModel(c),
PoolingType: c.Uint("pooling_type", 0), poolingType: pooling.Type(c.Uint("pooling_type", 0)),
} }
m.Cache = kvcache.NewWrapperCache( m.Cache = kvcache.NewWrapperCache(

View File

@ -16,7 +16,7 @@ import (
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*VisionModel `gguf:"v"` *VisionModel `gguf:"v"`
*TextModel *TextModel
@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),

View File

@ -10,7 +10,7 @@ import (
type Model struct { type Model struct {
model.Base model.Base
model.SentencePieceModel model.SentencePiece
*TextModel *TextModel
} }
@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
TextModel: newTextModel(c), TextModel: newTextModel(c),
SentencePieceModel: model.NewSentencePieceModel( SentencePiece: model.NewSentencePiece(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),

View File

@ -1,6 +1,7 @@
package models package models
import ( import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n" _ "github.com/ollama/ollama/model/models/gemma3n"

View File

@ -12,18 +12,18 @@ import (
const spmWhitespaceSep = "▁" const spmWhitespaceSep = "▁"
type SentencePieceModel struct { type SentencePiece struct {
maxTokenLen int maxTokenLen int
vocab *Vocabulary vocab *Vocabulary
} }
var _ TextProcessor = (*SentencePieceModel)(nil) var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePieceModel) Vocabulary() *Vocabulary { func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab return spm.vocab
} }
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{} counter := map[int]int{}
@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen) "max token len", maxTokenLen)
return SentencePieceModel{ return SentencePiece{
maxTokenLen: maxTokenLen, maxTokenLen: maxTokenLen,
vocab: vocab, vocab: vocab,
} }
} }
func (spm SentencePieceModel) Is(id int32, special Special) bool { func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special) return spm.vocab.Is(id, special)
} }
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) { func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}} fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() { for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special) id := spm.vocab.Encode(special)
@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
return item return item
} }
func (spm SentencePieceModel) Decode(ids []int32) (string, error) { func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder var sb strings.Builder
for _, id := range ids { for _, id := range ids {
data := spm.vocab.Decode(id) data := spm.vocab.Decode(id)

View File

@ -12,7 +12,7 @@ import (
"github.com/ollama/ollama/convert/sentencepiece" "github.com/ollama/ollama/convert/sentencepiece"
) )
func loadSentencePieceVocab(t *testing.T) SentencePieceModel { func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper() t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model")) bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
} }
} }
return NewSentencePieceModel(&v) return NewSentencePiece(&v)
} }
func TestSentencePieceEncode(t *testing.T) { func TestSentencePieceEncode(t *testing.T) {
@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
}) })
} }
func TestSentencePieceModelDecodeByteTokens(t *testing.T) { func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{ vocab := &Vocabulary{
Values: []string{ Values: []string{
"normal", "normal",
@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
Scores: []float32{0, 0, 0, 0, 0}, Scores: []float32{0, 0, 0, 0, 0},
} }
spm := NewSentencePieceModel(vocab) spm := NewSentencePiece(vocab)
tests := []struct { tests := []struct {
name string name string

167
model/wordpiece.go Normal file
View File

@ -0,0 +1,167 @@
package model
import (
"fmt"
"iter"
"strings"
"unicode"
"github.com/ollama/ollama/logutil"
)
type WordPiece struct {
vocab *Vocabulary
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from original word piece which uses "##" to indicate subwords.
const ggmlPrefix = "▁"
var wordPieceReplacer = strings.NewReplacer(
" .", ".",
" ?", "?",
" !", "!",
" ,", ",",
" ' ", "'",
" n't", "n't",
" 'm", "'m",
" do not", " don't",
" 's", "'s",
" 've", "'ve",
" 're", "'re",
)
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
if id < 0 || int(id) >= len(wpm.vocab.Values) {
return "", fmt.Errorf("invalid token id: %d", id)
}
var separator string
piece := wpm.vocab.Values[id]
if i > 0 &&
(strings.HasPrefix(piece, ggmlPrefix) ||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
separator = " "
}
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
}
return sb.String(), nil
}
// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
return func(yield func(string) bool) {
runes := make([]rune, 0, len(s)*3)
for _, r := range s {
switch {
case r >= 0x4E00 && r <= 0x9FFF,
r >= 0x3400 && r <= 0x4DBF,
r >= 0x20000 && r <= 0x2A6DF,
r >= 0x2A700 && r <= 0x2B73F,
r >= 0x2B740 && r <= 0x2B81F,
r >= 0x2B820 && r <= 0x2CEAF,
r >= 0xF900 && r <= 0xFAFF,
r >= 0x2F800 && r <= 0x2FA1F:
runes = append(runes, ' ', r, ' ')
default:
runes = append(runes, r)
}
}
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
// split on but keep punctuation
var start int
for start < len(w) {
end := strings.IndexFunc(w[start:], unicode.IsPunct)
if end < 0 {
end = len(w) - start
} else if end == 0 {
end = 1
}
if !yield(w[start : start+end]) {
return
}
start += end
}
}
}
}
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// TODO: use [UNK] from config
unk := wpm.vocab.Encode("[UNK]")
for word := range wpm.words(s) {
var start int
var pieces []int32
for start < len(word) {
end := len(word)
var piece int32
for start < end {
subword := word[start:end]
if start == 0 {
subword = ggmlPrefix + subword
}
// TODO: some models might not want [ToLower]
piece = wpm.vocab.Encode(strings.ToLower(subword))
if piece >= 0 {
break
}
end--
}
if piece < 0 {
// Unknown token
pieces = pieces[:0]
break
}
pieces = append(pieces, piece)
start = end
}
if len(pieces) > 0 {
ids = append(ids, pieces...)
} else {
ids = append(ids, unk)
}
}
if addSpecial && len(ids) > 0 {
ids = wpm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary) WordPiece {
return WordPiece{
vocab: vocab,
}
}

51
model/wordpiece_test.go Normal file
View File

@ -0,0 +1,51 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
})
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@ -488,7 +488,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
truncate := true truncate := true
if req.Truncate != nil && !*req.Truncate { if req.Truncate != nil && !*req.Truncate {
truncate = false truncate = false
} }
@ -555,7 +554,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
tokens = tokens[:ctxLen] tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens) s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})