mirror of https://github.com/ollama/ollama.git
restructure
image processing Update model.go Update model.go Update model.go no projector no projector vision model scaffold ... ... wip ... rebase fix patch merger tidy ... Update model_vision.go server: do not attempt to parse offset file as gguf This logic was causing issues for me when importing a gguf that had some padding at the end of the file. The valid gguf would be read, but then it would try to read the offset as a different gguf file. This does not seem right. Update process_image_test.go apply norm prompt processing prompt processing fix post tokenize fix gguf padding + populate the split patch embeddings ... ... another shot at patch embeddings ... patch embedding Update model_vision.go split pixels
This commit is contained in:
parent
7e55273f89
commit
57823c39b5
|
@ -198,6 +198,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
conv = &qwen2Model{}
|
||||
case "Qwen2_5_VLForConditionalGeneration":
|
||||
conv = &qwen25vlModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "CohereForCausalLM":
|
||||
|
|
|
@ -0,0 +1,188 @@
|
|||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
type qwen25vlModel struct {
|
||||
ModelParameters
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
|
||||
VisionModel struct {
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
//HeadDim uint32 `json:"num_heads"`
|
||||
//RopeTheta float32 `json:"rope_theta"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
WindowSize uint32 `json:"window_size"`
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen25vlModel)(nil)
|
||||
|
||||
func (q *qwen25vlModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
kv["qwen25vl.block_count"] = q.HiddenLayers
|
||||
kv["qwen25vl.context_length"] = q.MaxPositionEmbeddings
|
||||
kv["qwen25vl.embedding_length"] = q.HiddenSize
|
||||
kv["qwen25vl.feed_forward_length"] = q.IntermediateSize
|
||||
kv["qwen25vl.attention.head_count"] = q.NumAttentionHeads
|
||||
kv["qwen25vl.attention.head_count_kv"] = q.NumKeyValueHeads
|
||||
kv["qwen25vl.rope.freq_base"] = q.RopeTheta
|
||||
kv["qwen25vl.attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||
|
||||
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen25vlModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if strings.HasSuffix(t.Name(), "patch_embed.proj.weight") {
|
||||
// var buf bytes.Buffer
|
||||
// if _, err := t.WriteTo(&buf); err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// newTensors := splitPatchEmbed(buf, t.Kind(), t.Shape())
|
||||
// out = append(out, newTensors...)
|
||||
// } else if strings.HasPrefix(t.Name(), "v.blk.") {
|
||||
// skip
|
||||
} else {
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *qwen25vlModel) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"visual.blocks", "v.blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"model.norm", "output_norm",
|
||||
}
|
||||
}
|
||||
|
||||
func splitPatchEmbed(buf bytes.Buffer, kind uint32, shape []uint64) []ggml.Tensor {
|
||||
slog.Debug("patch stuff", "kind", kind, "shape", shape)
|
||||
|
||||
if kind != tensorKindF16 {
|
||||
panic("tensor is of wrong type")
|
||||
}
|
||||
|
||||
if len(shape) != 5 || (len(shape) == 5 && shape[2] != 2) {
|
||||
panic("wrong sized tensor")
|
||||
}
|
||||
|
||||
// determine the size of the tensor based on its shape
|
||||
shapeToSize := func(s []int) int {
|
||||
r := 1
|
||||
for _, n := range s {
|
||||
r *= int(n)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// tensor.WithShape() wants []int
|
||||
intShape := make([]int, len(shape))
|
||||
for i, v := range shape {
|
||||
intShape[i] = int(v)
|
||||
}
|
||||
|
||||
u16s := make([]uint16, shapeToSize(intShape))
|
||||
if err := binary.Read(&buf, binary.LittleEndian, u16s); err != nil {
|
||||
panic("bad read")
|
||||
}
|
||||
|
||||
f32s := make([]float32, len(u16s))
|
||||
for i := range u16s {
|
||||
f32s[i] = float16.Frombits(u16s[i]).Float32()
|
||||
}
|
||||
|
||||
newTensors := []ggml.Tensor{}
|
||||
|
||||
getDataFromSlice := func(f32s []float32, shape []int, s []tensor.Slice) patchEmbed {
|
||||
slog.Debug("getDataFromSlice", "num f32s", len(f32s), "shape", shape)
|
||||
n := tensor.New(tensor.WithShape(shape...), tensor.WithBacking(f32s))
|
||||
t, err := n.Slice(s...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(t.Materialize().(*tensor.Dense), 0)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
slog.Debug("first vals", "val 1", ts[0][0], "val 2", ts[0][1], "val 3", ts[0][2])
|
||||
|
||||
var f16s patchEmbed
|
||||
for _, row := range ts {
|
||||
for _, col := range row {
|
||||
f16s = append(f16s, float16.Fromfloat32(col).Bits())
|
||||
}
|
||||
}
|
||||
|
||||
return f16s
|
||||
}
|
||||
|
||||
p := getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(0, 1, 1), nil, nil})
|
||||
newTensors = append(newTensors, ggml.Tensor{
|
||||
Name: "v.patch_embed.0.weight",
|
||||
Kind: kind,
|
||||
Shape: append(shape[:2], shape[3:]...),
|
||||
WriterTo: p,
|
||||
})
|
||||
|
||||
p = getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(1, 2, 1), nil, nil})
|
||||
newTensors = append(newTensors, ggml.Tensor{
|
||||
Name: "v.patch_embed.1.weight",
|
||||
Kind: kind,
|
||||
Shape: append(shape[:2], shape[3:]...),
|
||||
WriterTo: p,
|
||||
})
|
||||
|
||||
return newTensors
|
||||
}
|
||||
|
||||
type patchEmbed []uint16
|
||||
|
||||
func (t patchEmbed) WriteTo(w io.Writer) (int64, error) {
|
||||
err := binary.Write(w, binary.LittleEndian, t)
|
||||
return 0, err
|
||||
}
|
|
@ -541,6 +541,7 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
|
|||
})
|
||||
|
||||
var s uint64
|
||||
var alignment int64 = 32
|
||||
for _, t := range ts {
|
||||
t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
|
||||
if err := ggufWriteTensorInfo(ws, t); err != nil {
|
||||
|
@ -655,5 +656,9 @@ func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
|
|||
}
|
||||
|
||||
func ggufPadding(offset, align int64) int64 {
|
||||
// if we already fit perfectly onto a 16 byte boundary, don't bother padding
|
||||
if ((align-offset%align)%align)%16 == 0 {
|
||||
return 0
|
||||
}
|
||||
return (align - offset%align) % align
|
||||
}
|
||||
|
|
|
@ -551,10 +551,16 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, co
|
|||
func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
|
||||
|
|
|
@ -191,6 +191,7 @@ type Tensor interface {
|
|||
|
||||
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
|
||||
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, sections [4]int, config RoPEConfig) Tensor
|
||||
|
||||
Sin(ctx Context) Tensor
|
||||
Cos(ctx Context) Tensor
|
||||
|
|
|
@ -1042,15 +1042,6 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|||
}
|
||||
}
|
||||
|
||||
// GGML RoPE types
|
||||
// These are the types used in the C implementation of RoPE
|
||||
const (
|
||||
ropeTypeNorm C.int = 0
|
||||
ropeTypeNeox C.int = 2
|
||||
ropeTypeMrope C.int = 8
|
||||
ropeTypeVision C.int = 24
|
||||
)
|
||||
|
||||
// RoPE applies Rotary Position Embeddings to the tensor
|
||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
|
||||
if ropeFactors == nil {
|
||||
|
@ -1066,21 +1057,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
|
|||
config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
|
||||
}
|
||||
|
||||
// Map Go RopeType to C implementation constants
|
||||
var ropeTypeC C.int
|
||||
switch config.Type {
|
||||
case ml.RopeTypeNormal:
|
||||
ropeTypeC = ropeTypeNorm
|
||||
case ml.RopeTypeNeox:
|
||||
ropeTypeC = ropeTypeNeox
|
||||
case ml.RopeTypeMRoPE:
|
||||
ropeTypeC = ropeTypeMrope
|
||||
case ml.RopeTypeVision:
|
||||
ropeTypeC = ropeTypeVision
|
||||
default:
|
||||
ropeTypeC = ropeTypeNorm
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_rope_ext(
|
||||
|
@ -1089,7 +1065,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config
|
|||
positionIDs.(*Tensor).t,
|
||||
ropeFactors.(*Tensor).t,
|
||||
C.int(config.Dim),
|
||||
ropeTypeC,
|
||||
ropeTypeToC(config.Type),
|
||||
C.int(config.YarnCtxTrain),
|
||||
C.float(config.Base),
|
||||
C.float(config.Scale),
|
||||
|
@ -1107,6 +1083,60 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
|||
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
|
||||
}
|
||||
}
|
||||
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, sections [4]int, config ml.RoPEConfig) ml.Tensor {
|
||||
if ropeFactors == nil {
|
||||
ropeFactors = &Tensor{b: t.b}
|
||||
}
|
||||
|
||||
dequant := t.t
|
||||
if C.ggml_is_quantized(t.t._type) {
|
||||
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_rope_multi(
|
||||
ctx.(*Context).ctx,
|
||||
dequant,
|
||||
positionIDs.(*Tensor).t,
|
||||
ropeFactors.(*Tensor).t,
|
||||
C.int(config.Dim),
|
||||
(*C.int)(unsafe.Pointer(§ions[0])),
|
||||
ropeTypeToC(config.Type),
|
||||
C.int(config.YarnCtxTrain),
|
||||
C.float(config.Base),
|
||||
C.float(config.Scale),
|
||||
C.float(config.YarnExtFactor),
|
||||
C.float(config.YarnAttnFactor),
|
||||
C.float(config.YarnBetaFast),
|
||||
C.float(config.YarnBetaSlow),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// GGML RoPE types
|
||||
// These are the types used in the C implementation of RoPE
|
||||
const (
|
||||
ropeTypeNorm C.int = 0
|
||||
ropeTypeNeox C.int = 2
|
||||
ropeTypeMrope C.int = 8
|
||||
ropeTypeVision C.int = 24
|
||||
)
|
||||
|
||||
func ropeTypeToC(ropeType ml.RopeType) C.int {
|
||||
switch ropeType {
|
||||
case ml.RopeTypeNormal:
|
||||
return ropeTypeNorm
|
||||
case ml.RopeTypeNeox:
|
||||
return ropeTypeNeox
|
||||
case ml.RopeTypeMRoPE:
|
||||
return ropeTypeMrope
|
||||
case ml.RopeTypeVision:
|
||||
return ropeTypeVision
|
||||
default:
|
||||
return ropeTypeNorm
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
func setup(t *testing.T) ml.Backend {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
models := filepath.Join(home, ".ollama", "models")
|
||||
|
||||
b, err := New(context.TODO(), filepath.Join(models, "blobs", "sha256-667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"), ml.BackendParams{NumGPULayers: 99})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestUnfoldConv(t *testing.T) {
|
||||
b := setup(t)
|
||||
ctx := b.NewContext().Input()
|
||||
t.Cleanup(func() { ctx.Close() })
|
||||
|
||||
tiles, channels, height, width := 5, 3, 336, 336
|
||||
patchSize := 14
|
||||
|
||||
tt := ctx.Arange(0, float32(tiles*channels*height*width), 1, ml.DTypeF32).Reshape(ctx, width, height, channels, tiles)
|
||||
t.Log("tt", tt.Shape())
|
||||
t.Log(ml.Dump(ctx, tt))
|
||||
|
||||
kernel := ctx.Empty(ml.DTypeF32, patchSize, patchSize, channels)
|
||||
t.Log("kernel", kernel.Shape())
|
||||
t.Log(ml.Dump(ctx, kernel))
|
||||
|
||||
tt = kernel.IM2Col(ctx, tt, patchSize, patchSize, 0, 0, 1, 1)
|
||||
t.Log("tt", tt.Shape())
|
||||
t.Log(ml.Dump(ctx, tt))
|
||||
|
||||
tt = tt.Reshape(ctx, tt.Dim(0), tt.Dim(1)*tt.Dim(2), tt.Dim(3))
|
||||
t.Log("tt", tt.Shape())
|
||||
t.Log(ml.Dump(ctx, tt))
|
||||
}
|
|
@ -57,7 +57,7 @@ func newTextModel(c fs.Config) *TextModel {
|
|||
},
|
||||
),
|
||||
Layers: make([]TextLayer, numBlocks),
|
||||
TextOptions: &TextOptions{
|
||||
TextConfig: &TextConfig{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
|
|
|
@ -17,6 +17,7 @@ type TextOptions struct {
|
|||
hiddenSize, numHeads, numKVHeads, headDim int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
ropeConfig ml.RoPEConfig
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
|
@ -40,7 +41,6 @@ type SelfAttention struct {
|
|||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
ropeType := uint32(0)
|
||||
headDim := opts.headDim
|
||||
if headDim == 0 {
|
||||
headDim = opts.hiddenSize / opts.numHeads
|
||||
|
@ -48,11 +48,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
@ -63,7 +63,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
return key.RoPE(ctx, shift, nil, m.TextOptions.ropeConfig), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
|
@ -167,9 +167,13 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
|
|||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeDim: c.Uint("rope.dimension_count"),
|
||||
ropeConfig: ml.RoPEConfig{
|
||||
Base: c.Float("rope.freq_base", 10000.0),
|
||||
Scale: c.Float("rope.freq_scale", 1.0),
|
||||
Dim: c.Uint("rope.dimension_count"),
|
||||
Type: ml.RopeTypeNormal,
|
||||
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package qwen25vl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"image"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
|
@ -12,147 +13,151 @@ import (
|
|||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
ctxLen, hiddenSize, numHeads, numKVHeads int
|
||||
eps float32
|
||||
ropeConfig ml.RoPEConfig
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*PatchMerger `gguf:"mm"`
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*Options
|
||||
ImageProcessor
|
||||
}
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
|
||||
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
|
||||
// Implement MultimodalProcessor interface
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
type PatchMerger struct {
|
||||
MLPLayer1 *nn.Linear `gguf:"0"`
|
||||
MLPLayer2 *nn.Linear `gguf:"2"`
|
||||
}
|
||||
|
||||
// Forward computes patch merging for the vision model
|
||||
func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||
// Get dimensions
|
||||
hiddenSize := visionOutputs.Dim(0)
|
||||
numPositions := visionOutputs.Dim(1)
|
||||
batchSize := visionOutputs.Dim(2)
|
||||
|
||||
reshaped := visionOutputs.Reshape(ctx, hiddenSize*4, numPositions/4, batchSize)
|
||||
|
||||
// Apply first linear layer (mm_0_w, mm_0_b)
|
||||
hidden := pm.MLPLayer1.Forward(ctx, reshaped)
|
||||
|
||||
activated := hidden.GELU(ctx)
|
||||
|
||||
// Apply second linear layer (mm_1_w, mm_1_b)
|
||||
output := pm.MLPLayer2.Forward(ctx, activated)
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
TextModel: NewTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
},
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: &Options{
|
||||
ctxLen: int(c.Uint("context_length")),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeConfig: ml.RoPEConfig{
|
||||
Base: c.Float("rope.freq_base"),
|
||||
Scale: c.Float("rope.freq_scale", 1),
|
||||
Dim: c.Uint("rope.dimension_count", 128),
|
||||
Type: ml.RopeTypeNeox,
|
||||
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 32768))),
|
||||
},
|
||||
},
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type imageFeatures struct {
|
||||
Tensor ml.Tensor
|
||||
GridT int
|
||||
GridH int
|
||||
GridW int
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// SelfAttention implements the multi-head self-attention mechanism
|
||||
// with separate projections for query, key, value and output transformations
|
||||
type SelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
|
||||
}
|
||||
|
||||
// MLP implements the feed-forward network component with SwiGLU activation
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
// Apply SwiGLU activation gating
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
// Project back to hidden dimension
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// Layer represents a single transformer layer combining self-attention and feed-forward components
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *SelfAttention
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
// Self-attention branch with residual connection
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
// Feed-forward branch with residual connection
|
||||
residual = hiddenState
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
f32s, gridT, gridH, gridW, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate tensor dimensions
|
||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
||||
numPatches := gridT * gridH * gridW
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.eps)
|
||||
|
||||
return &imageFeatures{
|
||||
Tensor: visionOutputs,
|
||||
GridT: gridT,
|
||||
GridH: gridH,
|
||||
GridW: gridW,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
|
||||
// Get image token IDs from config
|
||||
imageToken := 151655
|
||||
visionStartToken := 151652
|
||||
visionEndToken := 151653
|
||||
|
||||
// Get merge size from config
|
||||
mergeSize := m.ImageProcessor.mergeSize
|
||||
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
// If not a multimodal input, add it to the result unchanged
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
// This is an image token with multimodal data
|
||||
features := inp.Multimodal.(*imageFeatures)
|
||||
|
||||
// Get grid dimensions from the features
|
||||
gridT := features.GridT
|
||||
gridH := features.GridH
|
||||
gridW := features.GridW
|
||||
|
||||
// Calculate tokens per grid based on grid dimensions
|
||||
mergeLength := mergeSize * mergeSize
|
||||
gridProduct := gridT * gridH * gridW
|
||||
tokensPerGrid := gridProduct / mergeLength
|
||||
|
||||
// First add the vision start token
|
||||
result = append(result, input.Input{Token: int32(visionStartToken)})
|
||||
|
||||
// Add the image token with the multimodal tensor data at the first position
|
||||
result = append(result, input.Input{
|
||||
Token: int32(imageToken),
|
||||
Multimodal: features.Tensor,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
})
|
||||
|
||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||
for range tokensPerGrid - 1 {
|
||||
result = append(result, input.Input{Token: int32(imageToken)})
|
||||
}
|
||||
|
||||
result = append(result, input.Input{Token: int32(visionEndToken)})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
// Convert input tokens and positions to tensors
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -163,25 +168,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Initial token embedding
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
// Process through transformer layers
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("qwen25vl", New)
|
||||
model.Register("qwen2vl", New)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
package qwen25vl
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
func TestPostTokenize(t *testing.T) {
|
||||
// Set up test inputs
|
||||
model := &Model{}
|
||||
mockHash := uint64(12345678)
|
||||
|
||||
inputs := []input.Input{
|
||||
{Token: 123}, // Regular token
|
||||
{Token: 456}, // Regular token
|
||||
{Token: 151655, Multimodal: &ggml.Tensor{}, MultimodalHash: mockHash}, // Image token
|
||||
{Token: 789}, // Regular token
|
||||
}
|
||||
|
||||
// Run the function being tested
|
||||
result, err := model.PostTokenize(inputs)
|
||||
if err != nil {
|
||||
t.Fatalf("PostTokenize returned error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the actual length first
|
||||
expectedLength := 21
|
||||
if len(result) != expectedLength {
|
||||
t.Fatalf("Result has wrong length: got %d, expected %d", len(result), expectedLength)
|
||||
}
|
||||
|
||||
// Check key positions only
|
||||
checkPositions := map[int]int32{
|
||||
0: 123, // First regular token
|
||||
1: 456, // Second regular token
|
||||
2: 151652, // Vision start token
|
||||
4: 151655, // First placeholder token
|
||||
19: 151653, // Vision end token
|
||||
20: 789, // Final regular token
|
||||
}
|
||||
|
||||
for pos, expectedToken := range checkPositions {
|
||||
if pos >= len(result) {
|
||||
t.Errorf("Position %d is out of bounds (result length: %d)", pos, len(result))
|
||||
continue
|
||||
}
|
||||
if result[pos].Token != expectedToken {
|
||||
t.Errorf("Position %d: expected token %d, got %d", pos, expectedToken, result[pos].Token)
|
||||
}
|
||||
}
|
||||
|
||||
// Check multimodal data is preserved
|
||||
if result[3].MultimodalHash != mockHash {
|
||||
t.Errorf("Multimodal hash not preserved: got %d, expected %d",
|
||||
result[3].MultimodalHash, mockHash)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
package qwen25vl
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type TextOptions struct {
|
||||
ctxLen, hiddenSize, numHeads, numKVHeads int
|
||||
eps float32
|
||||
ropeConfig ml.RoPEConfig
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*TextOptions
|
||||
}
|
||||
|
||||
func NewTextModel(c fs.Config) *TextModel {
|
||||
m := TextModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
},
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
TextOptions: &TextOptions{
|
||||
ctxLen: int(c.Uint("context_length")),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeConfig: ml.RoPEConfig{
|
||||
Base: c.Float("rope.freq_base"),
|
||||
Scale: c.Float("rope.freq_scale", 1),
|
||||
Dim: c.Uint("rope.dimension_count", 128),
|
||||
Type: ml.RopeTypeNeox,
|
||||
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 128000))),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return &m
|
||||
}
|
||||
|
||||
// SelfAttention implements the multi-head self-attention mechanism
|
||||
// with separate projections for query, key, value and output transformations
|
||||
type SelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
|
||||
}
|
||||
|
||||
// MLP implements the feed-forward network component with SwiGLU activation
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
// Apply SwiGLU activation gating
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
// Project back to hidden dimension
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// Layer represents a single transformer layer combining self-attention and feed-forward components
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *SelfAttention
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
// Self-attention branch with residual connection
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
// Feed-forward branch with residual connection
|
||||
residual = hiddenState
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
||||
// Initial token embedding
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
|
||||
// Process through transformer layers
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState), nil
|
||||
}
|
|
@ -0,0 +1,260 @@
|
|||
package qwen25vl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
var batchSize int = 1
|
||||
|
||||
// VisionSelfAttention implements self-attention for the Qwen vision model
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
// Forward computes self-attention for the vision model
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
config := ml.RoPEConfig{
|
||||
Dim: uint32(opts.headDim / 2),
|
||||
Type: ml.RopeTypeMRoPE,
|
||||
Base: opts.ropeTheta,
|
||||
Scale: 1.0,
|
||||
YarnConfig: ml.DefaultYarnConfig(128000),
|
||||
}
|
||||
|
||||
query = query.RoPEMulti(
|
||||
ctx,
|
||||
positionIDs,
|
||||
nil,
|
||||
[4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4},
|
||||
config,
|
||||
)
|
||||
key = key.RoPEMulti(
|
||||
ctx,
|
||||
positionIDs,
|
||||
nil,
|
||||
[4]int{opts.headDim / 4, opts.headDim / 4, opts.headDim / 4, opts.headDim / 4},
|
||||
config,
|
||||
)
|
||||
|
||||
// Scale factor for scaled dot-product attention
|
||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, scale, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
// VisionMLP implements the MLP for the Qwen vision model
|
||||
type VisionMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
// Forward computes the MLP for the vision model
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Using GEGLU activation: (Gate * Up) * GELU(Gate)
|
||||
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
|
||||
upOutput := mlp.Up.Forward(ctx, hiddenStates)
|
||||
hiddenStates = gateOutput.GELU(ctx).Mul(ctx, upOutput)
|
||||
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
// VisionEncoderLayer implements an encoder layer for the Qwen vision model
|
||||
type VisionEncoderLayer struct {
|
||||
Norm1 *nn.RMSNorm `gguf:"ln1"`
|
||||
Norm2 *nn.RMSNorm `gguf:"ln2"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
// Forward computes an encoder layer for the vision model
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positionIDs, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
// VisionModelOptions contains configuration options for the Qwen vision model
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
intermediateSize int
|
||||
imageSize int
|
||||
patchSize int
|
||||
numChannels int
|
||||
eps float32
|
||||
ropeTheta float32
|
||||
outHiddenSize int
|
||||
}
|
||||
|
||||
type PatchEmbedding struct {
|
||||
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
|
||||
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
|
||||
}
|
||||
|
||||
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSize int) ml.Tensor {
|
||||
shape := pixelValues.Shape()
|
||||
numChannels := 3
|
||||
temporalPatchSize := 2
|
||||
embedDim := 1280
|
||||
numPatches := shape[1] / temporalPatchSize
|
||||
|
||||
// Split the input tensor into two temporal slices and process each separately
|
||||
// First temporal slice (frame 0)
|
||||
slice0 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 0, 1).Contiguous(ctx)
|
||||
reshaped0 := slice0.Reshape(ctx,
|
||||
patchSize, // height
|
||||
patchSize, // width
|
||||
numChannels, // channels
|
||||
numPatches) // batch
|
||||
|
||||
// Second temporal slice (frame 1)
|
||||
slice1 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 1, 1).Contiguous(ctx)
|
||||
reshaped1 := slice1.Reshape(ctx,
|
||||
patchSize, // height
|
||||
patchSize, // width
|
||||
numChannels, // channels
|
||||
numPatches) // batch
|
||||
|
||||
// Apply the appropriate convolution to each temporal slice
|
||||
// PatchConv0 corresponds to weights for temporal frame 0
|
||||
// PatchConv1 corresponds to weights for temporal frame 1
|
||||
s0, s1 := patchSize, patchSize // Use full stride as in original
|
||||
p0, p1 := 0, 0 // padding
|
||||
d0, d1 := 1, 1 // dilation
|
||||
|
||||
output0 := pe.PatchConv0.Forward(ctx, reshaped0, s0, s1, p0, p1, d0, d1)
|
||||
output1 := pe.PatchConv1.Forward(ctx, reshaped1, s0, s1, p0, p1, d0, d1)
|
||||
|
||||
// Add the outputs from the two temporal convolutions
|
||||
combined := output0.Add(ctx, output1)
|
||||
|
||||
// Reshape to required output dimensions
|
||||
result := combined.Reshape(ctx, embedDim, numPatches)
|
||||
|
||||
fmt.Println(ml.Dump(ctx, result))
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// VisionPatchMerger implements patch merging for the Qwen vision model
|
||||
type VisionPatchMerger struct {
|
||||
LNQ *nn.RMSNorm `gguf:"ln_q"`
|
||||
MLP *nn.Linear `gguf:"mlp"`
|
||||
}
|
||||
|
||||
// Forward computes patch merging for the vision model
|
||||
func (pm *VisionPatchMerger) Forward(ctx ml.Context, x ml.Tensor, outDim, contextDim, spatialMergeSize int) ml.Tensor {
|
||||
hiddenSize := contextDim * (spatialMergeSize * spatialMergeSize)
|
||||
|
||||
// Normalize and reshape
|
||||
x = pm.LNQ.Forward(ctx, x, 1e-6)
|
||||
x = x.Reshape(ctx, -1, hiddenSize)
|
||||
|
||||
// Apply MLP for merging
|
||||
x = pm.MLP.Forward(ctx, x)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// VisionModel implements the Qwen vision model
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *PatchEmbedding
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
PostLayerNorm *nn.LayerNorm `gguf:"post_ln"`
|
||||
PatchMerger *VisionPatchMerger `gguf:"patch_merger"`
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
// Forward computes the vision model for an input tensor
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
// Calculate position IDs for 2D RoPE
|
||||
numPatchesH := pixelValues.Dim(0) / m.patchSize
|
||||
numPatchesW := pixelValues.Dim(1) / m.patchSize
|
||||
numPatches := numPatchesH * numPatchesW
|
||||
|
||||
// Extract patch embeddings
|
||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize)
|
||||
|
||||
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
|
||||
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"
|
||||
positions := make([]int32, numPatches*4)
|
||||
|
||||
for h := 0; h < numPatchesH; h++ {
|
||||
for w := 0; w < numPatchesW; w++ {
|
||||
idx := h*numPatchesW + w
|
||||
// For each position, store both h and w coordinates twice
|
||||
// This matches the pattern seen in the C++ implementation
|
||||
positions[idx*4] = int32(h) // y coordinate
|
||||
positions[idx*4+1] = int32(w) // x coordinate
|
||||
positions[idx*4+2] = int32(h) // y coordinate (repeated)
|
||||
positions[idx*4+3] = int32(w) // x coordinate (repeated)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the position IDs tensor with correct dimensions
|
||||
positionIDs, err := ctx.Input().FromIntSlice(positions, numPatches*4)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Apply encoder layers
|
||||
for _, layer := range m.Layers {
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
// newVisionModel creates a new instance of the Qwen vision model
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
|
||||
ropeTheta := c.Float("vision.rope_theta", 10000.0) // not set
|
||||
outHiddenSize := int(c.Uint("vision.out_embedding_length", 0)) // not set
|
||||
numHeads := int(c.Uint("vision.attention.head_count", 16))
|
||||
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: hiddenSize / numHeads,
|
||||
intermediateSize: int(c.Uint("vision.feed_forward_length", 0)),
|
||||
imageSize: int(c.Uint("vision.image_size", 560)),
|
||||
patchSize: patchSize,
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
ropeTheta: ropeTheta,
|
||||
outHiddenSize: outHiddenSize,
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,196 @@
|
|||
package qwen25vl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
// ImageProcessor contains configuration for the Qwen 2.5 VL image processing
|
||||
type ImageProcessor struct {
|
||||
imageSize int
|
||||
numChannels int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
mergeSize int
|
||||
minPixels int
|
||||
maxPixels int
|
||||
factor int
|
||||
rescaleFactor float32
|
||||
imageMean []float32
|
||||
imageStd []float32
|
||||
}
|
||||
|
||||
// newImageProcessor creates a new image processor with default values
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size", 560)),
|
||||
numChannels: 3,
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: 2,
|
||||
mergeSize: mergeSize,
|
||||
minPixels: 56 * 56,
|
||||
maxPixels: 28 * 28 * 4 * 1280,
|
||||
factor: patchSize * mergeSize,
|
||||
rescaleFactor: 1.0 / 255.0,
|
||||
imageMean: []float32{0.48145466, 0.4578275, 0.40821073},
|
||||
imageStd: []float32{0.26862954, 0.26130258, 0.27577711},
|
||||
}
|
||||
}
|
||||
|
||||
// SmartResize implements the smart resize algorithm
|
||||
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
factor := p.factor
|
||||
|
||||
if height < factor || width < factor {
|
||||
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
|
||||
} else if float64(max(height, width))/float64(min(height, width)) > 200 {
|
||||
aspectRatio := float64(max(height, width)) / float64(min(height, width))
|
||||
panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %f", aspectRatio))
|
||||
}
|
||||
|
||||
round := func(x float64) int {
|
||||
return int(math.Round(x))
|
||||
}
|
||||
hBar := round(float64(height)/float64(factor)) * factor
|
||||
wBar := round(float64(width)/float64(factor)) * factor
|
||||
|
||||
if hBar*wBar > p.maxPixels {
|
||||
beta := math.Sqrt(float64(height*width) / float64(p.maxPixels))
|
||||
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if hBar*wBar < p.minPixels {
|
||||
beta := math.Sqrt(float64(p.minPixels) / float64(height*width))
|
||||
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int, error) {
|
||||
origWidth := img.Bounds().Dx()
|
||||
origHeight := img.Bounds().Dy()
|
||||
|
||||
// Calculate smart resize dimensions
|
||||
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
|
||||
|
||||
// Resize image using existing functions
|
||||
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
||||
|
||||
normalizedPixels := imageproc.Normalize(
|
||||
resizedImg,
|
||||
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
|
||||
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
|
||||
true, // rescale
|
||||
true, // channelFirst
|
||||
)
|
||||
|
||||
// Calculate grid dimensions
|
||||
gridH := resizedHeight / p.patchSize
|
||||
gridW := resizedWidth / p.patchSize
|
||||
gridT := 1 // For single images, temporal dimension is 1
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, gridH, gridW, gridT)
|
||||
if err != nil {
|
||||
return nil, 0, 0, 0, fmt.Errorf("failed to create patches: %v", err)
|
||||
}
|
||||
|
||||
// Return patches and grid dimensions
|
||||
return patches, gridT, gridH, gridW, nil
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, gridW, gridT int) ([]float32, error) {
|
||||
channels := p.numChannels
|
||||
patchSize := p.patchSize
|
||||
mergeSize := p.mergeSize
|
||||
temporalPatchSize := p.temporalPatchSize
|
||||
|
||||
// Calculate output dimensions
|
||||
numPatches := gridT * gridH * gridW
|
||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||
|
||||
// Create output tensor
|
||||
result := make([]float32, numPatches*patchDim)
|
||||
|
||||
// Instead of the complex 9D reshape+transpose, directly extract patches
|
||||
// in the format expected by the forward pass
|
||||
patchIndex := 0
|
||||
|
||||
for t := 0; t < gridT; t++ {
|
||||
// For each patch in the grid
|
||||
for h := 0; h < gridH; h += mergeSize {
|
||||
for w := 0; w < gridW; w += mergeSize {
|
||||
// Handle the 2x2 merged patches
|
||||
for mh := 0; mh < mergeSize; mh++ {
|
||||
for mw := 0; mw < mergeSize; mw++ {
|
||||
// For each pixel in the patch
|
||||
for py := 0; py < patchSize; py++ {
|
||||
for px := 0; px < patchSize; px++ {
|
||||
// Calculate source coordinates
|
||||
y := (h+mh)*patchSize + py
|
||||
x := (w+mw)*patchSize + px
|
||||
|
||||
// For each channel
|
||||
for c := 0; c < channels; c++ {
|
||||
// Channel-first format (CHW)
|
||||
srcIdx := c*height*width + y*width + x
|
||||
|
||||
// Calculate destination index based on the expected layout
|
||||
// This is the key part that matches what the model expects
|
||||
dstIdx := patchIndex*patchDim +
|
||||
(c * temporalPatchSize * patchSize * patchSize) +
|
||||
(0 * patchSize * patchSize) + // temporal dim
|
||||
(py * patchSize) +
|
||||
px
|
||||
|
||||
if srcIdx < len(pixels) && dstIdx < len(result) {
|
||||
result[dstIdx] = pixels[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle temporal dimension padding (if needed)
|
||||
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||
for py := 0; py < patchSize; py++ {
|
||||
for px := 0; px < patchSize; px++ {
|
||||
for c := 0; c < channels; c++ {
|
||||
srcIdx := patchIndex*patchDim +
|
||||
(c * temporalPatchSize * patchSize * patchSize) +
|
||||
(0 * patchSize * patchSize) + // first temporal frame
|
||||
(py * patchSize) +
|
||||
px
|
||||
|
||||
dstIdx := patchIndex*patchDim +
|
||||
(c * temporalPatchSize * patchSize * patchSize) +
|
||||
(tp * patchSize * patchSize) + // current temporal frame
|
||||
(py * patchSize) +
|
||||
px
|
||||
|
||||
if srcIdx < len(result) && dstIdx < len(result) {
|
||||
result[dstIdx] = result[srcIdx] // Copy from first frame
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
patchIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package qwen25vl
|
||||
|
||||
import (
|
||||
"image"
|
||||
_ "image/jpeg" // Register JPEG decoder
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSmartResize(t *testing.T) {
|
||||
type smartResizeCase struct {
|
||||
TestImage image.Image
|
||||
Expected image.Point
|
||||
}
|
||||
|
||||
// Create an image processor with default values
|
||||
processor := ImageProcessor{
|
||||
imageSize: 560, // Example value
|
||||
numChannels: 3,
|
||||
factor: 28,
|
||||
minPixels: 56 * 56,
|
||||
maxPixels: 14 * 14 * 4 * 1280,
|
||||
}
|
||||
|
||||
cases := []smartResizeCase{
|
||||
{
|
||||
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 1024)),
|
||||
Expected: image.Point{980, 980},
|
||||
},
|
||||
{
|
||||
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
|
||||
Expected: image.Point{1036, 756},
|
||||
},
|
||||
{
|
||||
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
|
||||
Expected: image.Point{980, 980},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
b := c.TestImage.Bounds().Max
|
||||
x, y := processor.SmartResize(b.X, b.Y)
|
||||
actual := image.Point{x, y}
|
||||
if actual != c.Expected {
|
||||
t.Errorf("expected: %v, actual: %v", c.Expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -14,7 +14,7 @@ import (
|
|||
const (
|
||||
DefaultFactor = 28
|
||||
DefaultMinPixels = 56 * 56
|
||||
DefaultMaxPixels = 14 * 14 * 4 * 1280
|
||||
DefaultMaxPixels = 14 * 14 * 4 * 1280 // TODO: might need to change
|
||||
)
|
||||
|
||||
// smartResize calculates the size of the image to resize to based on the
|
||||
|
|
|
@ -497,43 +497,37 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var offset int64
|
||||
for offset < stat.Size() {
|
||||
f, n, err := ggml.Decode(blob, 0)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
f, n, err := ggml.Decode(blob, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mediatype := "application/vnd.ollama.image.model"
|
||||
if f.KV().Kind() == "adapter" {
|
||||
mediatype = "application/vnd.ollama.image.adapter"
|
||||
} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
var layer Layer
|
||||
if digest != "" && n == stat.Size() {
|
||||
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
if err != nil {
|
||||
slog.Debug("could not create new layer from layer", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mediatype := "application/vnd.ollama.image.model"
|
||||
if f.KV().Kind() == "adapter" {
|
||||
mediatype = "application/vnd.ollama.image.adapter"
|
||||
} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
var layer Layer
|
||||
if digest != "" && n == stat.Size() && offset == 0 {
|
||||
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
if err != nil {
|
||||
slog.Debug("could not create new layer from layer", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
|
||||
if layer.Digest == "" {
|
||||
layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, f})
|
||||
offset = n
|
||||
}
|
||||
|
||||
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
|
||||
if layer.Digest == "" {
|
||||
layer, err = NewLayer(io.NewSectionReader(blob, 0, n), mediatype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, f})
|
||||
|
||||
return detectChatTemplate(layers)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue