restructure

image processing

Update model.go

Update model.go

Update model.go

no projector

no projector

vision model scaffold

...

...

wip

...

rebase

fix patch merger

tidy

...

Update model_vision.go

server: do not attempt to parse offset file as gguf

This logic was causing issues for me when importing a gguf that had some padding at the end of the file. The valid gguf would be read, but then it would try to read the offset as a different gguf file. This does not seem right.

Update process_image_test.go

apply norm

prompt processing

prompt processing

fix post tokenize

fix gguf padding + populate the split patch embeddings

...

...

another shot at patch embeddings

...

patch embedding

Update model_vision.go

split pixels
This commit is contained in:
Bruce MacDonald 2025-04-02 10:41:51 -07:00
parent 7e55273f89
commit 57823c39b5
17 changed files with 1212 additions and 214 deletions

View File

@ -198,6 +198,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
conv = &phi3Model{}
case "Qwen2ForCausalLM":
conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25vlModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

188
convert/convert_qwen25vl.go Normal file
View File

@ -0,0 +1,188 @@
package convert
import (
"bytes"
"encoding/binary"
"io"
"log/slog"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16"
)
type qwen25vlModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
HiddenLayers uint32 `json:"num_hidden_layers"`
RopeTheta float32 `json:"rope_theta"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
VisionModel struct {
PatchSize uint32 `json:"patch_size"`
//HeadDim uint32 `json:"num_heads"`
//RopeTheta float32 `json:"rope_theta"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
WindowSize uint32 `json:"window_size"`
} `json:"vision_config"`
}
var _ ModelConverter = (*qwen25vlModel)(nil)
func (q *qwen25vlModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"
kv["qwen25vl.block_count"] = q.HiddenLayers
kv["qwen25vl.context_length"] = q.MaxPositionEmbeddings
kv["qwen25vl.embedding_length"] = q.HiddenSize
kv["qwen25vl.feed_forward_length"] = q.IntermediateSize
kv["qwen25vl.attention.head_count"] = q.NumAttentionHeads
kv["qwen25vl.attention.head_count_kv"] = q.NumKeyValueHeads
kv["qwen25vl.rope.freq_base"] = q.RopeTheta
kv["qwen25vl.attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
return kv
}
func (q *qwen25vlModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if strings.HasSuffix(t.Name(), "patch_embed.proj.weight") {
// var buf bytes.Buffer
// if _, err := t.WriteTo(&buf); err != nil {
// panic(err)
// }
// newTensors := splitPatchEmbed(buf, t.Kind(), t.Shape())
// out = append(out, newTensors...)
// } else if strings.HasPrefix(t.Name(), "v.blk.") {
// skip
} else {
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen25vlModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"visual.blocks", "v.blk",
"input_layernorm", "attn_norm",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.q_proj", "attn_q",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm",
"model.norm", "output_norm",
}
}
func splitPatchEmbed(buf bytes.Buffer, kind uint32, shape []uint64) []ggml.Tensor {
slog.Debug("patch stuff", "kind", kind, "shape", shape)
if kind != tensorKindF16 {
panic("tensor is of wrong type")
}
if len(shape) != 5 || (len(shape) == 5 && shape[2] != 2) {
panic("wrong sized tensor")
}
// determine the size of the tensor based on its shape
shapeToSize := func(s []int) int {
r := 1
for _, n := range s {
r *= int(n)
}
return r
}
// tensor.WithShape() wants []int
intShape := make([]int, len(shape))
for i, v := range shape {
intShape[i] = int(v)
}
u16s := make([]uint16, shapeToSize(intShape))
if err := binary.Read(&buf, binary.LittleEndian, u16s); err != nil {
panic("bad read")
}
f32s := make([]float32, len(u16s))
for i := range u16s {
f32s[i] = float16.Frombits(u16s[i]).Float32()
}
newTensors := []ggml.Tensor{}
getDataFromSlice := func(f32s []float32, shape []int, s []tensor.Slice) patchEmbed {
slog.Debug("getDataFromSlice", "num f32s", len(f32s), "shape", shape)
n := tensor.New(tensor.WithShape(shape...), tensor.WithBacking(f32s))
t, err := n.Slice(s...)
if err != nil {
panic(err)
}
ts, err := native.SelectF32(t.Materialize().(*tensor.Dense), 0)
if err != nil {
panic(err)
}
slog.Debug("first vals", "val 1", ts[0][0], "val 2", ts[0][1], "val 3", ts[0][2])
var f16s patchEmbed
for _, row := range ts {
for _, col := range row {
f16s = append(f16s, float16.Fromfloat32(col).Bits())
}
}
return f16s
}
p := getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(0, 1, 1), nil, nil})
newTensors = append(newTensors, ggml.Tensor{
Name: "v.patch_embed.0.weight",
Kind: kind,
Shape: append(shape[:2], shape[3:]...),
WriterTo: p,
})
p = getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(1, 2, 1), nil, nil})
newTensors = append(newTensors, ggml.Tensor{
Name: "v.patch_embed.1.weight",
Kind: kind,
Shape: append(shape[:2], shape[3:]...),
WriterTo: p,
})
return newTensors
}
type patchEmbed []uint16
func (t patchEmbed) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, t)
return 0, err
}

View File

@ -541,6 +541,7 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
})
var s uint64
var alignment int64 = 32
for _, t := range ts {
t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
if err := ggufWriteTensorInfo(ws, t); err != nil {
@ -655,5 +656,9 @@ func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
}
func ggufPadding(offset, align int64) int64 {
// if we already fit perfectly onto a 16 byte boundary, don't bother padding
if ((align-offset%align)%align)%16 == 0 {
return 0
}
return (align - offset%align) % align
}

View File

@ -551,10 +551,16 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, co
func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }

View File

@ -191,6 +191,7 @@ type Tensor interface {
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, sections [4]int, config RoPEConfig) Tensor
Sin(ctx Context) Tensor
Cos(ctx Context) Tensor

View File

@ -1042,15 +1042,6 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
}
// GGML RoPE types
// These are the types used in the C implementation of RoPE
const (
ropeTypeNorm C.int = 0
ropeTypeNeox C.int = 2
ropeTypeMrope C.int = 8
ropeTypeVision C.int = 24
)
// RoPE applies Rotary Position Embeddings to the tensor
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
if ropeFactors == nil {
@ -1066,21 +1057,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
}
// Map Go RopeType to C implementation constants
var ropeTypeC C.int
switch config.Type {
case ml.RopeTypeNormal:
ropeTypeC = ropeTypeNorm
case ml.RopeTypeNeox:
ropeTypeC = ropeTypeNeox
case ml.RopeTypeMRoPE:
ropeTypeC = ropeTypeMrope
case ml.RopeTypeVision:
ropeTypeC = ropeTypeVision
default:
ropeTypeC = ropeTypeNorm
}
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
@ -1089,7 +1065,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
positionIDs.(*Tensor).t,
ropeFactors.(*Tensor).t,
C.int(config.Dim),
ropeTypeC,
ropeTypeToC(config.Type),
C.int(config.YarnCtxTrain),
C.float(config.Base),
C.float(config.Scale),
@ -1107,6 +1083,60 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
}
}
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
dequant := t.t
if C.ggml_is_quantized(t.t._type) {
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
return &Tensor{
b: t.b,
t: C.ggml_rope_multi(
ctx.(*Context).ctx,
dequant,
positionIDs.(*Tensor).t,
ropeFactors.(*Tensor).t,
C.int(config.Dim),
(*C.int)(unsafe.Pointer(&sections[0])),
ropeTypeToC(config.Type),
C.int(config.YarnCtxTrain),
C.float(config.Base),
C.float(config.Scale),
C.float(config.YarnExtFactor),
C.float(config.YarnAttnFactor),
C.float(config.YarnBetaFast),
C.float(config.YarnBetaSlow),
),
}
}
// GGML RoPE types
// These are the types used in the C implementation of RoPE
const (
ropeTypeNorm C.int = 0
ropeTypeNeox C.int = 2
ropeTypeMrope C.int = 8
ropeTypeVision C.int = 24
)
func ropeTypeToC(ropeType ml.RopeType) C.int {
switch ropeType {
case ml.RopeTypeNormal:
return ropeTypeNorm
case ml.RopeTypeNeox:
return ropeTypeNeox
case ml.RopeTypeMRoPE:
return ropeTypeMrope
case ml.RopeTypeVision:
return ropeTypeVision
default:
return ropeTypeNorm
}
}
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{

View File

@ -0,0 +1,51 @@
package model
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/ml"
)
func setup(t *testing.T) ml.Backend {
home, err := os.UserHomeDir()
if err != nil {
t.Fatal(err)
}
models := filepath.Join(home, ".ollama", "models")
b, err := New(context.TODO(), filepath.Join(models, "blobs", "sha256-667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"), ml.BackendParams{NumGPULayers: 99})
if err != nil {
t.Fatal(err)
}
return b
}
func TestUnfoldConv(t *testing.T) {
b := setup(t)
ctx := b.NewContext().Input()
t.Cleanup(func() { ctx.Close() })
tiles, channels, height, width := 5, 3, 336, 336
patchSize := 14
tt := ctx.Arange(0, float32(tiles*channels*height*width), 1, ml.DTypeF32).Reshape(ctx, width, height, channels, tiles)
t.Log("tt", tt.Shape())
t.Log(ml.Dump(ctx, tt))
kernel := ctx.Empty(ml.DTypeF32, patchSize, patchSize, channels)
t.Log("kernel", kernel.Shape())
t.Log(ml.Dump(ctx, kernel))
tt = kernel.IM2Col(ctx, tt, patchSize, patchSize, 0, 0, 1, 1)
t.Log("tt", tt.Shape())
t.Log(ml.Dump(ctx, tt))
tt = tt.Reshape(ctx, tt.Dim(0), tt.Dim(1)*tt.Dim(2), tt.Dim(3))
t.Log("tt", tt.Shape())
t.Log(ml.Dump(ctx, tt))
}

View File

@ -57,7 +57,7 @@ func newTextModel(c fs.Config) *TextModel {
},
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),

View File

@ -17,6 +17,7 @@ type TextOptions struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
ropeConfig ml.RoPEConfig
}
type TextModel struct {
@ -40,7 +41,6 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(0)
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
@ -48,11 +48,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -63,7 +63,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, nil, m.TextOptions.ropeConfig), nil
}
type MLP struct {
@ -167,9 +167,13 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
ropeConfig: ml.RoPEConfig{
Base: c.Float("rope.freq_base", 10000.0),
Scale: c.Float("rope.freq_scale", 1.0),
Dim: c.Uint("rope.dimension_count"),
Type: ml.RopeTypeNormal,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
},
},
}

View File

@ -1,10 +1,11 @@
package qwen25vl
import (
"bytes"
"fmt"
"math"
"strings"
"image"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@ -12,147 +13,151 @@ import (
"github.com/ollama/ollama/model/input"
)
type Options struct {
ctxLen, hiddenSize, numHeads, numKVHeads int
eps float32
ropeConfig ml.RoPEConfig
}
type Model struct {
model.Base
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v,vision"`
*PatchMerger `gguf:"mm"`
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
ImageProcessor
}
func New(c ml.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
type PatchMerger struct {
MLPLayer1 *nn.Linear `gguf:"0"`
MLPLayer2 *nn.Linear `gguf:"2"`
}
// Forward computes patch merging for the vision model
func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
// Get dimensions
hiddenSize := visionOutputs.Dim(0)
numPositions := visionOutputs.Dim(1)
batchSize := visionOutputs.Dim(2)
reshaped := visionOutputs.Reshape(ctx, hiddenSize*4, numPositions/4, batchSize)
// Apply first linear layer (mm_0_w, mm_0_b)
hidden := pm.MLPLayer1.Forward(ctx, reshaped)
activated := hidden.GELU(ctx)
// Apply second linear layer (mm_1_w, mm_1_b)
output := pm.MLPLayer2.Forward(ctx, activated)
return output
}
func New(c fs.Config) (model.Model, error) {
m := &Model{
TextModel: NewTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
ctxLen: int(c.Uint("context_length")),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeConfig: ml.RoPEConfig{
Base: c.Float("rope.freq_base"),
Scale: c.Float("rope.freq_scale", 1),
Dim: c.Uint("rope.dimension_count", 128),
Type: ml.RopeTypeNeox,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 32768))),
},
},
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
type imageFeatures struct {
Tensor ml.Tensor
GridT int
GridH int
GridW int
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
// SelfAttention implements the multi-head self-attention mechanism
// with separate projections for query, key, value and output transformations
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
}
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
}
// MLP implements the feed-forward network component with SwiGLU activation
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
// Apply SwiGLU activation gating
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState)
}
// Layer represents a single transformer layer combining self-attention and feed-forward components
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
// Self-attention branch with residual connection
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
hiddenState = hiddenState.Add(ctx, residual)
// Feed-forward branch with residual connection
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
f32s, gridT, gridH, gridW, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
// Calculate tensor dimensions
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
numPatches := gridT * gridH * gridW
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
if err != nil {
return nil, fmt.Errorf("failed to create tensor from image: %w", err)
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.eps)
return &imageFeatures{
Tensor: visionOutputs,
GridT: gridT,
GridH: gridH,
GridW: gridW,
}, nil
}
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
// Get image token IDs from config
imageToken := 151655
visionStartToken := 151652
visionEndToken := 151653
// Get merge size from config
mergeSize := m.ImageProcessor.mergeSize
for _, inp := range inputs {
if inp.Multimodal == nil {
// If not a multimodal input, add it to the result unchanged
result = append(result, inp)
} else {
// This is an image token with multimodal data
features := inp.Multimodal.(*imageFeatures)
// Get grid dimensions from the features
gridT := features.GridT
gridH := features.GridH
gridW := features.GridW
// Calculate tokens per grid based on grid dimensions
mergeLength := mergeSize * mergeSize
gridProduct := gridT * gridH * gridW
tokensPerGrid := gridProduct / mergeLength
// First add the vision start token
result = append(result, input.Input{Token: int32(visionStartToken)})
// Add the image token with the multimodal tensor data at the first position
result = append(result, input.Input{
Token: int32(imageToken),
Multimodal: features.Tensor,
MultimodalHash: inp.MultimodalHash,
})
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
for range tokensPerGrid - 1 {
result = append(result, input.Input{Token: int32(imageToken)})
}
result = append(result, input.Input{Token: int32(visionEndToken)})
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
// Convert input tokens and positions to tensors
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
@ -163,25 +168,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return nil, err
}
// Initial token embedding
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
// Process through transformer layers
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
}
func init() {
model.Register("qwen25vl", New)
model.Register("qwen2vl", New)
}

View File

@ -0,0 +1,59 @@
package qwen25vl
import (
"testing"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/model/input"
)
func TestPostTokenize(t *testing.T) {
// Set up test inputs
model := &Model{}
mockHash := uint64(12345678)
inputs := []input.Input{
{Token: 123}, // Regular token
{Token: 456}, // Regular token
{Token: 151655, Multimodal: &ggml.Tensor{}, MultimodalHash: mockHash}, // Image token
{Token: 789}, // Regular token
}
// Run the function being tested
result, err := model.PostTokenize(inputs)
if err != nil {
t.Fatalf("PostTokenize returned error: %v", err)
}
// Verify the actual length first
expectedLength := 21
if len(result) != expectedLength {
t.Fatalf("Result has wrong length: got %d, expected %d", len(result), expectedLength)
}
// Check key positions only
checkPositions := map[int]int32{
0: 123, // First regular token
1: 456, // Second regular token
2: 151652, // Vision start token
4: 151655, // First placeholder token
19: 151653, // Vision end token
20: 789, // Final regular token
}
for pos, expectedToken := range checkPositions {
if pos >= len(result) {
t.Errorf("Position %d is out of bounds (result length: %d)", pos, len(result))
continue
}
if result[pos].Token != expectedToken {
t.Errorf("Position %d: expected token %d, got %d", pos, expectedToken, result[pos].Token)
}
}
// Check multimodal data is preserved
if result[3].MultimodalHash != mockHash {
t.Errorf("Multimodal hash not preserved: got %d, expected %d",
result[3].MultimodalHash, mockHash)
}
}

View File

@ -0,0 +1,165 @@
package qwen25vl
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
ctxLen, hiddenSize, numHeads, numKVHeads int
eps float32
ropeConfig ml.RoPEConfig
}
type TextModel struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
func NewTextModel(c fs.Config) *TextModel {
m := TextModel{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
ctxLen: int(c.Uint("context_length")),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeConfig: ml.RoPEConfig{
Base: c.Float("rope.freq_base"),
Scale: c.Float("rope.freq_scale", 1),
Dim: c.Uint("rope.dimension_count", 128),
Type: ml.RopeTypeNeox,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 128000))),
},
},
}
return &m
}
// SelfAttention implements the multi-head self-attention mechanism
// with separate projections for query, key, value and output transformations
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
}
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
}
// MLP implements the feed-forward network component with SwiGLU activation
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState)
}
// Layer represents a single transformer layer combining self-attention and feed-forward components
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
// Self-attention branch with residual connection
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
// Feed-forward branch with residual connection
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
// Initial token embedding
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
// Process through transformer layers
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}

View File

@ -0,0 +1,260 @@
package qwen25vl
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
// VisionSelfAttention implements self-attention for the Qwen vision model
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
// Forward computes self-attention for the vision model
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
config := ml.RoPEConfig{
Dim: uint32(opts.headDim / 2),
Type: ml.RopeTypeMRoPE,
Base: opts.ropeTheta,
Scale: 1.0,
YarnConfig: ml.DefaultYarnConfig(128000),
}
query = query.RoPEMulti(
ctx,
positionIDs,
nil,
[4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4},
config,
)
key = key.RoPEMulti(
ctx,
positionIDs,
nil,
[4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4},
config,
)
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
attention := nn.Attention(ctx, query, key, value, scale, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
// VisionMLP implements the MLP for the Qwen vision model
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
// Forward computes the MLP for the vision model
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using GEGLU activation: (Gate * Up) * GELU(Gate)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
upOutput := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.GELU(ctx).Mul(ctx, upOutput)
return mlp.Down.Forward(ctx, hiddenStates)
}
// VisionEncoderLayer implements an encoder layer for the Qwen vision model
type VisionEncoderLayer struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
Norm2 *nn.RMSNorm `gguf:"ln2"`
SelfAttention *VisionSelfAttention
MLP *VisionMLP
}
// Forward computes an encoder layer for the vision model
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positionIDs, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
// VisionModelOptions contains configuration options for the Qwen vision model
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeTheta float32
outHiddenSize int
}
type PatchEmbedding struct {
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
}
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor {
shape := pixelValues.Shape()
numChannels := 3
temporalPatchSize := 2
embedDim := 1280
numPatches := shape[1] / temporalPatchSize
// Split the input tensor into two temporal slices and process each separately
// First temporal slice (frame 0)
slice0 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 0, 1).Contiguous(ctx)
reshaped0 := slice0.Reshape(ctx,
patchSize, // height
patchSize, // width
numChannels, // channels
numPatches) // batch
// Second temporal slice (frame 1)
slice1 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 1, 1).Contiguous(ctx)
reshaped1 := slice1.Reshape(ctx,
patchSize, // height
patchSize, // width
numChannels, // channels
numPatches) // batch
// Apply the appropriate convolution to each temporal slice
// PatchConv0 corresponds to weights for temporal frame 0
// PatchConv1 corresponds to weights for temporal frame 1
s0, s1 := patchSize, patchSize // Use full stride as in original
p0, p1 := 0, 0 // padding
d0, d1 := 1, 1 // dilation
output0 := pe.PatchConv0.Forward(ctx, reshaped0, s0, s1, p0, p1, d0, d1)
output1 := pe.PatchConv1.Forward(ctx, reshaped1, s0, s1, p0, p1, d0, d1)
// Add the outputs from the two temporal convolutions
combined := output0.Add(ctx, output1)
// Reshape to required output dimensions
result := combined.Reshape(ctx, embedDim, numPatches)
fmt.Println(ml.Dump(ctx, result))
return result
}
// VisionPatchMerger implements patch merging for the Qwen vision model
type VisionPatchMerger struct {
LNQ *nn.RMSNorm `gguf:"ln_q"`
MLP *nn.Linear `gguf:"mlp"`
}
// Forward computes patch merging for the vision model
func (pm *VisionPatchMerger) Forward(ctx ml.Context, x ml.Tensor, outDim, contextDim, spatialMergeSize int) ml.Tensor {
hiddenSize := contextDim * (spatialMergeSize * spatialMergeSize)
// Normalize and reshape
x = pm.LNQ.Forward(ctx, x, 1e-6)
x = x.Reshape(ctx, -1, hiddenSize)
// Apply MLP for merging
x = pm.MLP.Forward(ctx, x)
return x
}
// VisionModel implements the Qwen vision model
type VisionModel struct {
PatchEmbedding *PatchEmbedding
Layers []VisionEncoderLayer `gguf:"blk"`
PostLayerNorm *nn.LayerNorm `gguf:"post_ln"`
PatchMerger *VisionPatchMerger `gguf:"patch_merger"`
*VisionModelOptions
}
// Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
// Calculate position IDs for 2D RoPE
numPatchesH := pixelValues.Dim(0) / m.patchSize
numPatchesW := pixelValues.Dim(1) / m.patchSize
numPatches := numPatchesH * numPatchesW
// Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize)
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"
positions := make([]int32, numPatches*4)
for h := 0; h < numPatchesH; h++ {
for w := 0; w < numPatchesW; w++ {
idx := h*numPatchesW + w
// For each position, store both h and w coordinates twice
// This matches the pattern seen in the C++ implementation
positions[idx*4] = int32(h) // y coordinate
positions[idx*4+1] = int32(w) // x coordinate
positions[idx*4+2] = int32(h) // y coordinate (repeated)
positions[idx*4+3] = int32(w) // x coordinate (repeated)
}
}
// Create the position IDs tensor with correct dimensions
positionIDs, err := ctx.Input().FromIntSlice(positions, numPatches*4)
if err != nil {
panic(err)
}
// Apply encoder layers
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions)
}
hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps)
return hiddenStates
}
// newVisionModel creates a new instance of the Qwen vision model
func newVisionModel(c fs.Config) *VisionModel {
patchSize := int(c.Uint("vision.patch_size", 14))
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
ropeTheta := c.Float("vision.rope_theta", 10000.0) // not set
outHiddenSize := int(c.Uint("vision.out_embedding_length", 0)) // not set
numHeads := int(c.Uint("vision.attention.head_count", 16))
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
intermediateSize: int(c.Uint("vision.feed_forward_length", 0)),
imageSize: int(c.Uint("vision.image_size", 560)),
patchSize: patchSize,
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
ropeTheta: ropeTheta,
outHiddenSize: outHiddenSize,
},
}
}

View File

@ -0,0 +1,196 @@
package qwen25vl
import (
"fmt"
"image"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
// ImageProcessor contains configuration for the Qwen 2.5 VL image processing
type ImageProcessor struct {
imageSize int
numChannels int
patchSize int
temporalPatchSize int
mergeSize int
minPixels int
maxPixels int
factor int
rescaleFactor float32
imageMean []float32
imageStd []float32
}
// newImageProcessor creates a new image processor with default values
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 560)),
numChannels: 3,
patchSize: patchSize,
temporalPatchSize: 2,
mergeSize: mergeSize,
minPixels: 56 * 56,
maxPixels: 28 * 28 * 4 * 1280,
factor: patchSize * mergeSize,
rescaleFactor: 1.0 / 255.0,
imageMean: []float32{0.48145466, 0.4578275, 0.40821073},
imageStd: []float32{0.26862954, 0.26130258, 0.27577711},
}
}
// SmartResize implements the smart resize algorithm
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
if height < factor || width < factor {
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
} else if float64(max(height, width))/float64(min(height, width)) > 200 {
aspectRatio := float64(max(height, width)) / float64(min(height, width))
panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %f", aspectRatio))
}
round := func(x float64) int {
return int(math.Round(x))
}
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
if hBar*wBar > p.maxPixels {
beta := math.Sqrt(float64(height*width) / float64(p.maxPixels))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if hBar*wBar < p.minPixels {
beta := math.Sqrt(float64(p.minPixels) / float64(height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int, error) {
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image using existing functions
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
normalizedPixels := imageproc.Normalize(
resizedImg,
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
true, // rescale
true, // channelFirst
)
// Calculate grid dimensions
gridH := resizedHeight / p.patchSize
gridW := resizedWidth / p.patchSize
gridT := 1 // For single images, temporal dimension is 1
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, gridH, gridW, gridT)
if err != nil {
return nil, 0, 0, 0, fmt.Errorf("failed to create patches: %v", err)
}
// Return patches and grid dimensions
return patches, gridT, gridH, gridW, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, gridW, gridT int) ([]float32, error) {
channels := p.numChannels
patchSize := p.patchSize
mergeSize := p.mergeSize
temporalPatchSize := p.temporalPatchSize
// Calculate output dimensions
numPatches := gridT * gridH * gridW
patchDim := channels * temporalPatchSize * patchSize * patchSize
// Create output tensor
result := make([]float32, numPatches*patchDim)
// Instead of the complex 9D reshape+transpose, directly extract patches
// in the format expected by the forward pass
patchIndex := 0
for t := 0; t < gridT; t++ {
// For each patch in the grid
for h := 0; h < gridH; h += mergeSize {
for w := 0; w < gridW; w += mergeSize {
// Handle the 2x2 merged patches
for mh := 0; mh < mergeSize; mh++ {
for mw := 0; mw < mergeSize; mw++ {
// For each pixel in the patch
for py := 0; py < patchSize; py++ {
for px := 0; px < patchSize; px++ {
// Calculate source coordinates
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
// For each channel
for c := 0; c < channels; c++ {
// Channel-first format (CHW)
srcIdx := c*height*width + y*width + x
// Calculate destination index based on the expected layout
// This is the key part that matches what the model expects
dstIdx := patchIndex*patchDim +
(c * temporalPatchSize * patchSize * patchSize) +
(0 * patchSize * patchSize) + // temporal dim
(py * patchSize) +
px
if srcIdx < len(pixels) && dstIdx < len(result) {
result[dstIdx] = pixels[srcIdx]
}
}
}
}
// Handle temporal dimension padding (if needed)
for tp := 1; tp < temporalPatchSize; tp++ {
for py := 0; py < patchSize; py++ {
for px := 0; px < patchSize; px++ {
for c := 0; c < channels; c++ {
srcIdx := patchIndex*patchDim +
(c * temporalPatchSize * patchSize * patchSize) +
(0 * patchSize * patchSize) + // first temporal frame
(py * patchSize) +
px
dstIdx := patchIndex*patchDim +
(c * temporalPatchSize * patchSize * patchSize) +
(tp * patchSize * patchSize) + // current temporal frame
(py * patchSize) +
px
if srcIdx < len(result) && dstIdx < len(result) {
result[dstIdx] = result[srcIdx] // Copy from first frame
}
}
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}

View File

@ -0,0 +1,47 @@
package qwen25vl
import (
"image"
_ "image/jpeg" // Register JPEG decoder
"testing"
)
func TestSmartResize(t *testing.T) {
type smartResizeCase struct {
TestImage image.Image
Expected image.Point
}
// Create an image processor with default values
processor := ImageProcessor{
imageSize: 560, // Example value
numChannels: 3,
factor: 28,
minPixels: 56 * 56,
maxPixels: 14 * 14 * 4 * 1280,
}
cases := []smartResizeCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 1024)),
Expected: image.Point{980, 980},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
Expected: image.Point{1036, 756},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
Expected: image.Point{980, 980},
},
}
for _, c := range cases {
b := c.TestImage.Bounds().Max
x, y := processor.SmartResize(b.X, b.Y)
actual := image.Point{x, y}
if actual != c.Expected {
t.Errorf("expected: %v, actual: %v", c.Expected, actual)
}
}
}

View File

@ -14,7 +14,7 @@ import (
const (
DefaultFactor = 28
DefaultMinPixels = 56 * 56
DefaultMaxPixels = 14 * 14 * 4 * 1280
DefaultMaxPixels = 14 * 14 * 4 * 1280 // TODO: might need to change
)
// smartResize calculates the size of the image to resize to based on the

View File

@ -497,43 +497,37 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
return nil, err
}
var offset int64
for offset < stat.Size() {
f, n, err := ggml.Decode(blob, 0)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
f, n, err := ggml.Decode(blob, 0)
if err != nil {
return nil, err
}
mediatype := "application/vnd.ollama.image.model"
if f.KV().Kind() == "adapter" {
mediatype = "application/vnd.ollama.image.adapter"
} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
mediatype = "application/vnd.ollama.image.projector"
}
var layer Layer
if digest != "" && n == stat.Size() {
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
}
mediatype := "application/vnd.ollama.image.model"
if f.KV().Kind() == "adapter" {
mediatype = "application/vnd.ollama.image.adapter"
} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
mediatype = "application/vnd.ollama.image.projector"
}
var layer Layer
if digest != "" && n == stat.Size() && offset == 0 {
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
}
}
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype)
if err != nil {
return nil, err
}
}
layers = append(layers, &layerGGML{layer, f})
offset = n
}
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(blob, 0, n), mediatype)
if err != nil {
return nil, err
}
}
layers = append(layers, &layerGGML{layer, f})
return detectChatTemplate(layers)
}