add new gemma model (#11204)

* update patches

* cherry pick metal mean kernel

* cherry pick cuda mean kernel

* gemma3n
This commit is contained in:
Michael Yang 2025-06-25 21:47:09 -07:00 committed by GitHub
parent ad118d8b13
commit 73b642e6f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 6084 additions and 54 deletions

View File

@ -190,6 +190,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Gemma3nForConditionalGeneration":
conv = &gemma3nModel{}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "Qwen2ForCausalLM":

168
convert/convert_gemma3n.go Normal file
View File

@ -0,0 +1,168 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"gonum.org/v1/gonum/stat/distuv"
)
type gemma3nModel struct {
ModelParameters
TextModel struct {
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
AltupActiveIdx uint32 `json:"altup_active_idx"`
AltupCoefClip float32 `json:"altup_coef_clip"`
AltupCorrectScale bool `json:"altup_correct_scale"`
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
AltupNumInputs uint32 `json:"altup_num_inputs"`
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
IntermediateSize uint32 `json:"intermediate_size"`
LaurelRank uint32 `json:"laurel_rank"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
} `json:"text_config"`
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
norm := distuv.Normal{Mu: 0, Sigma: 1}
for _, v := range m.TextModel.ActivationSparsityPattern {
if !yield(float32(norm.Quantile(float64(v)))) {
break
}
}
})
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for _, t := range m.TextModel.LayerTypes {
if !yield(t == "sliding_attention") {
break
}
}
})
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
kv["gemma3n.laurel_rank"] = m.TextModel.LaurelRank
kv["gemma3n.num_kv_shared_layers"] = m.TextModel.NumKVSharedLayers
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
return kv
}
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
out, ts := mergeTensors(ts,
merge{"altup_proj.*.weight", "altup_proj.weight"},
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
)
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "audio_tower"),
strings.Contains(t.Name(), "embed_audio"),
strings.Contains(t.Name(), "vision_tower"),
strings.Contains(t.Name(), "embed_vision"):
// TODO: handle audio and vision towers
continue
case strings.Contains(t.Name(), "altup_predict_coef"),
strings.Contains(t.Name(), "altup_correct_coef"):
if m.TextModel.AltupCoefClip > 0 {
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
if err != nil {
return nil, err
}
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *gemma3nModel) Replacements() []string {
return []string{
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"post_feedforward_layernorm", "post_ffw_norm",
"per_layer_input_gate", "inp_gate",
"per_layer_projection", "proj",
"post_per_layer_input_norm", "post_norm",
"altup.", "altup_",
"modality_router", "router",
"prediction_coefs", "predict_coef",
"correction_coefs", "correct_coef",
"correct_output_scale", "correct_scale.weight",
"laurel.", "laurel_",
"linear_left", "l",
"linear_right", "r",
"post_laurel_norm", "post_norm",
}
}

View File

@ -10,4 +10,5 @@ type Config interface {
Strings(string, ...[]string) []string
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
}

View File

@ -166,6 +166,11 @@ func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
return val.values
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
return val.values
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",

View File

@ -609,6 +609,10 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
err = writeGGUFArray(ws, ggufTypeString, v)
case *array[string]:
err = writeGGUFArray(ws, ggufTypeString, v.values)
case []bool:
err = writeGGUFArray(ws, ggufTypeBool, v)
case *array[bool]:
err = writeGGUFArray(ws, ggufTypeBool, v.values)
default:
return fmt.Errorf("improper type for '%s'", k)
}

2
go.mod
View File

@ -25,6 +25,7 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0
gonum.org/v1/gonum v0.15.0
)
require (
@ -44,7 +45,6 @@ require (
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)

View File

@ -150,7 +150,7 @@ index 4cce5166..7f6617fa 100644
llama_model_loader::llama_model_loader(
const std::string & fname,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 3a4e72a3..831b68c0 100644
index 3a4e72a3..db62973f 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1402,6 +1402,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {

View File

@ -22,10 +22,10 @@ multiple batches of processing until everything is complete.
4 files changed, 59 insertions(+), 79 deletions(-)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index c22687e4..c5948e8f 100644
index dca22d8b..1f3a3956 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
@@ -947,9 +947,12 @@ int llama_context::decode(llama_batch & inp_batch) {
// find KV slot
if (!kv_self->find_slot(ubatch)) {
@ -41,7 +41,7 @@ index c22687e4..c5948e8f 100644
}
ggml_backend_sched_reset(sched.get());
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
@@ -1965,9 +1968,12 @@ void llama_context::opt_epoch_iter(
// TODO: not sure if this is needed
if (!kv_self->find_slot(ubatch)) {

View File

@ -10,10 +10,10 @@ Subject: [PATCH] add argsort and cuda copy for i32
3 files changed, 192 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index becdae07..7a44b6cf 100644
index 955fec59..654e2f28 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
@@ -6822,6 +6822,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
@ -59,7 +59,7 @@ index becdae07..7a44b6cf 100644
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
@@ -6833,6 +6872,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
@ -195,7 +195,7 @@ index 607ded85..53b02634 100644
+ }
}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 2d46176e..47383486 100644
index d027271f..4abd01d7 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
@ -257,7 +257,7 @@ index 2d46176e..47383486 100644
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
@@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@ -266,7 +266,7 @@ index 2d46176e..47383486 100644
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
@@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;

View File

@ -0,0 +1,169 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 19 Jun 2025 08:05:21 +0300
Subject: [PATCH] metal : add mean kernel (#14267)
* metal : add mean kernel
ggml-ci
* cont : dedup implementation
ggml-ci
---
ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++---
ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------
2 files changed, 67 insertions(+), 14 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index ee4f2dcb..f20f5615 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+ GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG:
return false; // TODO: implement
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (dst->op) {
+ case GGML_OP_SUM_ROWS:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+ break;
+ case GGML_OP_MEAN:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ }
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
+ nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
};
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 9cfddf45..08e8d807 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -956,31 +956,61 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
+template <bool norm>
kernel void kernel_sum_rows(
+ constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
- constant ggml_metal_kargs_sum_rows & args,
- uint3 tpig[[thread_position_in_grid]]) {
- int64_t i3 = tpig.z;
- int64_t i2 = tpig.y;
- int64_t i1 = tpig.x;
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ int64_t i3 = tgpig.z;
+ int64_t i2 = tgpig.y;
+ int64_t i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
- float row_sum = 0;
+ float sumf = 0;
- for (int64_t i0 = 0; i0 < args.ne00; i0++) {
- row_sum += src_row[i0];
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+ sumf += src_row[i0];
}
- dst_row[0] = row_sum;
+ sumf = simd_sum(sumf);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ if (tpitg.x == 0) {
+ dst_row[0] = norm ? sumf / args.ne00 : sumf;
+ }
}
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

File diff suppressed because it is too large Load Diff

View File

@ -253,6 +253,7 @@ type Tensor interface {
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Div(ctx Context, t2 Tensor) Tensor
@ -276,6 +277,7 @@ type Tensor interface {
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor
Reshape(ctx Context, shape ...int) Tensor
@ -297,6 +299,12 @@ type Tensor interface {
TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
Mean(ctx Context) Tensor
Variance(ctx Context) Tensor
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Clamp(ctx Context, min, max float32) Tensor
}
// ScaledDotProductAttention implements a fused attention

View File

@ -297,7 +297,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
}
case contains(t.Name, "cls", "output", "output_norm"):
case contains(t.Name, "cls", "output", "output_norm",
"altup_proj", "altup_unembd_proj",
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
createTensor(tensor{source: t}, output.bts, blocks)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
// TODO: assign vision tensors to the gpu if possible
@ -893,6 +895,13 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) Sub(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
if dim < 0 || dim >= C.GGML_MAX_DIMS {
panic("invalid dimension")
@ -1200,6 +1209,13 @@ func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
@ -1275,3 +1291,42 @@ func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
}
}
func (t *Tensor) Mean(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mean(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Variance(ctx ml.Context) ml.Tensor {
return t.Add(ctx, t.Mean(ctx).Scale(ctx, -1)).
Sqr(ctx).
SumRows(ctx).
Scale(ctx, 1/float64(t.Dim(0)))
}
func (t *Tensor) Stddev(ctx ml.Context) ml.Tensor {
return t.Variance(ctx).Sqrt(ctx)
}
func (t *Tensor) Sqr(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqr(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqrt(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}

View File

@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll

View File

@ -35,6 +35,7 @@
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
@ -2322,6 +2323,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_MEAN:
ggml_cuda_op_mean(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
@ -3211,6 +3215,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;

View File

@ -0,0 +1,19 @@
#include "mean.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -1,25 +1,9 @@
#include "sumrows.cuh"
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col == 0) {
dst[row] = sum;
}
}
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}

View File

@ -1,5 +1,4 @@
#include "common.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -3434,31 +3434,61 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
constant ggml_metal_kargs_sum_rows & args,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
int64_t i1 = tpig.x;
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float row_sum = 0;
float sumf = 0;
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
row_sum += src_row[i0];
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
}
dst_row[0] = row_sum;
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
}
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

View File

@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG:
return false; // TODO: implement
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
id<MTLComputePipelineState> pipeline = nil;
switch (dst->op) {
case GGML_OP_SUM_ROWS:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
break;
case GGML_OP_MEAN:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
break;
default:
GGML_ABORT("fatal error");
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{

View File

@ -956,31 +956,61 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
constant ggml_metal_kargs_sum_rows & args,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
int64_t i1 = tpig.x;
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float row_sum = 0;
float sumf = 0;
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
row_sum += src_row[i0];
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
}
dst_row[0] = row_sum;
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
}
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

View File

@ -0,0 +1,52 @@
package gemma3n
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.SentencePieceModel
*TextModel
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return m.TextModel.Forward(ctx, batch, m.Cache)
}
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePieceModel: model.NewSentencePieceModel(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
}
// TODO: setup hybrid (local sliding window + global) cache
m.Cache = kvcache.NewWrapperCache(
kvcache.NewCausalCache(m.Shift),
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
)
return &m, nil
}
func init() {
model.Register("gemma3n", New)
}

View File

@ -0,0 +1,360 @@
package gemma3n
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
type TextModel struct {
TokenEmbedding *TextScaledWordEmbedding `gguf:"token_embd"`
*PerLayerProjector
AltupEmbd *nn.Linear `gguf:"altup_proj"`
AltupUnembd *nn.Linear `gguf:"altup_unembd_proj"`
TextLayers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
TextOptions
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
// Create a tensor of a single float32 value of 1.0 to use for altup correction
one := ctx.Input().FromFloatSlice([]float32{1.0}, 1)
inputs := m.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(m.hiddenSize)))
inputsPerLayer := m.PerLayerProjector.Forward(ctx, batch, inputs, &m.TextOptions)
targetMagnitude := inputs.Sqr(ctx).Mean(ctx).Sqrt(ctx)
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
hiddenState := inputs.Repeat(ctx, 2, m.altupInputs-1)
altupProj := m.AltupEmbd.Forward(ctx, hiddenState)
altupProj = altupProj.Mul(ctx, targetMagnitude.Div(ctx, altupProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
hiddenStates := inputs.Concat(ctx, altupProj, 2)
firstSharedKeyValue := m.hiddenLayers - m.sharedKeyValueLayers
for i, layer := range m.TextLayers {
if i < firstSharedKeyValue {
cache.SetLayer(i)
} else if m.isLocal(i) {
cache.SetLayer(firstSharedKeyValue - 2)
} else {
cache.SetLayer(firstSharedKeyValue - 1)
}
var layerType int
ropeBase := m.ropeBase
if m.isLocal(i) {
layerType = 1
ropeBase = m.ropeBaseLocal
}
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
// inputPerLayer = inputsPerLayer[:, i, :]
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
}
// hiddenStates = hiddenStates[:, :, 0]
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
// hiddenState = hiddenStates[:, :, 1:]
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
hiddenStates = hiddenStates0.Concat(ctx, altupUnembdProj, 2)
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.ropeBase
if m.isLocal(layer) {
ropeBase = m.ropeBaseLocal
}
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
}
type TextScaledWordEmbedding struct {
*nn.Embedding
}
func (e TextScaledWordEmbedding) Forward(ctx ml.Context, inputIDs ml.Tensor, scale float64) ml.Tensor {
return e.Embedding.Forward(ctx, inputIDs).Scale(ctx, scale)
}
type PerLayerProjector struct {
TokenEmbedding *TextScaledWordEmbedding `gguf:"per_layer_token_embd"`
Projector *nn.Linear `gguf:"per_layer_model_proj"`
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
}
func (p PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, batch.Inputs.Dim(0), batch.Inputs.Dim(1))
perLayerProjection := p.Projector.Forward(ctx, inputs)
perLayerProjection = perLayerProjection.Scale(ctx, math.Sqrt(float64(opts.hiddenSize)))
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
if inputsPerLayer != nil {
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
}
return perLayerProjection
}
type TextLayer struct {
*AltUp
*Laurel
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *TextAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
PerLayerProjection *nn.Linear `gguf:"proj"`
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
}
func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, positions, one ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, activationSparsityScale float64, opts *TextOptions) ml.Tensor {
predictions := d.Predict(ctx, hiddenStates, opts)
active := opts.altupActive(ctx, predictions)
attn := d.AttentionNorm.Forward(ctx, active, opts.eps)
laurel := d.Laurel.Forward(ctx, attn, opts)
attn = d.Attention.Forward(ctx, attn, positions, cache, sharedKV, ropeBase, opts)
attn = d.PostAttentionNorm.Forward(ctx, attn, opts.eps)
attn = active.Add(ctx, attn)
attn = attn.Add(ctx, laurel).Scale(ctx, 1/math.Sqrt(2))
mlp := d.MLPNorm.Forward(ctx, attn, opts.eps)
mlp = d.MLP.Forward(ctx, mlp, activationSparsityScale)
mlp = d.PostMLPNorm.Forward(ctx, mlp, opts.eps)
mlp = attn.Add(ctx, mlp)
predictions = d.Correct(ctx, predictions, mlp, one, opts)
active = opts.altupActive(ctx, predictions)
if opts.altupCorrectScale {
active = d.ScaleCorrectedOutput(ctx, active)
}
active = d.PerLayerInputGate.Forward(ctx, active)
active = active.GELU(ctx)
active = active.Mul(ctx, perLayerInput)
active = d.PerLayerProjection.Forward(ctx, active)
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
// inactive := predictions[:, :, 1:]
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
active = inactive.Add(ctx, active)
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
return predictions0.Concat(ctx, active, 2)
}
type AltUp struct {
CorrectionScale ml.Tensor `gguf:"altup_correct_scale.weight"`
PredictionCoefficient *nn.Linear `gguf:"altup_predict_coef"`
CorrectionCoefficient *nn.Linear `gguf:"altup_correct_coef"`
Router *nn.Linear `gguf:"altup_router"`
RouterNorm *nn.RMSNorm `gguf:"altup_router_norm"`
}
func (a AltUp) computeRouterModalities(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
routerInputs := a.RouterNorm.Forward(ctx, hiddenStates, opts.eps).Scale(ctx, 1.0/float64(opts.hiddenSize))
return a.Router.Forward(ctx, routerInputs).Tanh(ctx)
}
func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
modalities := a.computeRouterModalities(ctx, opts.altupActive(ctx, hiddenStates), opts)
coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
predictions := coefficients.Mulmat(ctx, hiddenStates)
predictions = predictions.Add(ctx, hiddenStates)
return predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}
func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
innovation := activated.Sub(ctx, opts.altupActive(ctx, predictions))
innovation = innovation.Repeat(ctx, 2, opts.altupInputs)
modalities := a.computeRouterModalities(ctx, activated, opts)
coefficients := a.CorrectionCoefficient.Forward(ctx, modalities)
coefficients = coefficients.Add(ctx, one)
coefficients = coefficients.Reshape(ctx, 1, coefficients.Dim(0), coefficients.Dim(1))
coefficients = coefficients.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
corrected := innovation.Mul(ctx, coefficients)
corrected = corrected.Add(ctx, predictions)
return corrected
}
func (a AltUp) ScaleCorrectedOutput(ctx ml.Context, predictions ml.Tensor) ml.Tensor {
return predictions.Mul(ctx, a.CorrectionScale)
}
type Laurel struct {
LinearLeft *nn.Linear `gguf:"laurel_l"`
LinearRight *nn.Linear `gguf:"laurel_r"`
PostLaurelNorm *nn.RMSNorm `gguf:"laurel_post_norm"`
}
func (l Laurel) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = l.LinearLeft.Forward(ctx, hiddenStates)
hiddenStates = l.LinearRight.Forward(ctx, hiddenStates)
hiddenStates = l.PostLaurelNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates.Add(ctx, residual)
}
type TextAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, opts *TextOptions) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
var key, value ml.Tensor
if !sharedKV {
key = attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
value = attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
value = value.RMSNorm(ctx, nil, opts.eps)
}
attention := nn.Attention(ctx, query, key, value, 1., cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSparsityScale float64) ml.Tensor {
upStates := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates)
if activationSparsityScale > 0 {
mean := hiddenStates.Mean(ctx)
std := hiddenStates.Stddev(ctx).Scale(ctx, activationSparsityScale)
cutoff := mean.Add(ctx, std)
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
}
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
return hiddenStates
}
type TextOptions struct {
hiddenLayers int
hiddenSize int
hiddenSizePerLayerInput int
numHeads, numKVHeads int
keyLength, valueLength int
sharedKeyValueLayers int
altupActiveIndex int
altupInputs int
altupCorrectScale bool
eps float32
ropeBase float32
ropeBaseLocal float32
ropeScale float32
slidingWindowPattern []bool
activationSparsityScale []float32
}
func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
// t[:, :, o.altupActiveIndex]
return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
}
func (o *TextOptions) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func (o *TextOptions) isLocal(i int) bool {
return o.slidingWindowPattern[i]
}
func newTextModel(c fs.Config) *TextModel {
return &TextModel{
TextLayers: make([]TextLayer, c.Uint("block_count")),
TextOptions: TextOptions{
hiddenLayers: int(c.Uint("block_count")),
hiddenSize: int(c.Uint("embedding_length")),
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
sharedKeyValueLayers: int(c.Uint("attention.shared_kv_layers")),
altupActiveIndex: int(c.Uint("altup.active_idx")),
altupInputs: int(c.Uint("altup.num_inputs")),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: c.Float("rope.freq_base", 1_000_000),
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
ropeScale: c.Float("rope.freq_scale", 1.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
activationSparsityScale: c.Floats("activation_sparsity_scale"),
},
}
}

View File

@ -3,6 +3,7 @@ package models
import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"