mirror of https://github.com/ollama/ollama.git
add new gemma model (#11204)
* update patches * cherry pick metal mean kernel * cherry pick cuda mean kernel * gemma3n
This commit is contained in:
parent
ad118d8b13
commit
73b642e6f3
|
@ -190,6 +190,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
|||
conv = &gemma2Model{}
|
||||
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
|
||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
package convert
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"gonum.org/v1/gonum/stat/distuv"
|
||||
)
|
||||
|
||||
type gemma3nModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextModel struct {
|
||||
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
|
||||
AltupActiveIdx uint32 `json:"altup_active_idx"`
|
||||
AltupCoefClip float32 `json:"altup_coef_clip"`
|
||||
AltupCorrectScale bool `json:"altup_correct_scale"`
|
||||
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
|
||||
AltupNumInputs uint32 `json:"altup_num_inputs"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
LaurelRank uint32 `json:"laurel_rank"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
norm := distuv.Normal{Mu: 0, Sigma: 1}
|
||||
for _, v := range m.TextModel.ActivationSparsityPattern {
|
||||
if !yield(float32(norm.Quantile(float64(v)))) {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
|
||||
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
|
||||
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
|
||||
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
|
||||
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
|
||||
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
|
||||
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
|
||||
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
|
||||
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for _, t := range m.TextModel.LayerTypes {
|
||||
if !yield(t == "sliding_attention") {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
|
||||
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
|
||||
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
|
||||
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
|
||||
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
|
||||
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
|
||||
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
|
||||
kv["gemma3n.laurel_rank"] = m.TextModel.LaurelRank
|
||||
kv["gemma3n.num_kv_shared_layers"] = m.TextModel.NumKVSharedLayers
|
||||
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
|
||||
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out, ts := mergeTensors(ts,
|
||||
merge{"altup_proj.*.weight", "altup_proj.weight"},
|
||||
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
|
||||
)
|
||||
|
||||
for _, t := range ts {
|
||||
switch {
|
||||
case strings.Contains(t.Name(), "audio_tower"),
|
||||
strings.Contains(t.Name(), "embed_audio"),
|
||||
strings.Contains(t.Name(), "vision_tower"),
|
||||
strings.Contains(t.Name(), "embed_vision"):
|
||||
// TODO: handle audio and vision towers
|
||||
continue
|
||||
case strings.Contains(t.Name(), "altup_predict_coef"),
|
||||
strings.Contains(t.Name(), "altup_correct_coef"):
|
||||
if m.TextModel.AltupCoefClip > 0 {
|
||||
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
|
||||
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Replacements() []string {
|
||||
return []string{
|
||||
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
|
||||
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
|
||||
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"pre_feedforward_layernorm", "ffn_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"per_layer_input_gate", "inp_gate",
|
||||
"per_layer_projection", "proj",
|
||||
"post_per_layer_input_norm", "post_norm",
|
||||
"altup.", "altup_",
|
||||
"modality_router", "router",
|
||||
"prediction_coefs", "predict_coef",
|
||||
"correction_coefs", "correct_coef",
|
||||
"correct_output_scale", "correct_scale.weight",
|
||||
"laurel.", "laurel_",
|
||||
"linear_left", "l",
|
||||
"linear_right", "r",
|
||||
"post_laurel_norm", "post_norm",
|
||||
}
|
||||
}
|
|
@ -10,4 +10,5 @@ type Config interface {
|
|||
Strings(string, ...[]string) []string
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
}
|
||||
|
|
|
@ -166,6 +166,11 @@ func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
|||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
|
|
|
@ -609,6 +609,10 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
|||
err = writeGGUFArray(ws, ggufTypeString, v)
|
||||
case *array[string]:
|
||||
err = writeGGUFArray(ws, ggufTypeString, v.values)
|
||||
case []bool:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v)
|
||||
case *array[bool]:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v.values)
|
||||
default:
|
||||
return fmt.Errorf("improper type for '%s'", k)
|
||||
}
|
||||
|
|
2
go.mod
2
go.mod
|
@ -25,6 +25,7 @@ require (
|
|||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/tools v0.30.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -44,7 +45,6 @@ require (
|
|||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gonum.org/v1/gonum v0.15.0 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
|
|
|
@ -150,7 +150,7 @@ index 4cce5166..7f6617fa 100644
|
|||
llama_model_loader::llama_model_loader(
|
||||
const std::string & fname,
|
||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||
index 3a4e72a3..831b68c0 100644
|
||||
index 3a4e72a3..db62973f 100644
|
||||
--- a/src/llama-model.cpp
|
||||
+++ b/src/llama-model.cpp
|
||||
@@ -1402,6 +1402,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
|
|
|
@ -22,10 +22,10 @@ multiple batches of processing until everything is complete.
|
|||
4 files changed, 59 insertions(+), 79 deletions(-)
|
||||
|
||||
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
||||
index c22687e4..c5948e8f 100644
|
||||
index dca22d8b..1f3a3956 100644
|
||||
--- a/src/llama-context.cpp
|
||||
+++ b/src/llama-context.cpp
|
||||
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
@@ -947,9 +947,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
|
||||
// find KV slot
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
|
@ -41,7 +41,7 @@ index c22687e4..c5948e8f 100644
|
|||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
|
||||
@@ -1965,9 +1968,12 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
// TODO: not sure if this is needed
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
|
|
|
@ -10,10 +10,10 @@ Subject: [PATCH] add argsort and cuda copy for i32
|
|||
3 files changed, 192 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||
index becdae07..7a44b6cf 100644
|
||||
index 955fec59..654e2f28 100644
|
||||
--- a/ggml/src/ggml-cpu/ops.cpp
|
||||
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||||
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||
@@ -6822,6 +6822,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -59,7 +59,7 @@ index becdae07..7a44b6cf 100644
|
|||
void ggml_compute_forward_argsort(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
|
||||
@@ -6833,6 +6872,10 @@ void ggml_compute_forward_argsort(
|
||||
{
|
||||
ggml_compute_forward_argsort_f32(params, dst);
|
||||
} break;
|
||||
|
@ -195,7 +195,7 @@ index 607ded85..53b02634 100644
|
|||
+ }
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
||||
index 2d46176e..47383486 100644
|
||||
index d027271f..4abd01d7 100644
|
||||
--- a/ggml/src/ggml-cuda/cpy.cu
|
||||
+++ b/ggml/src/ggml-cuda/cpy.cu
|
||||
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
||||
|
@ -257,7 +257,7 @@ index 2d46176e..47383486 100644
|
|||
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
||||
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
@@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
|
@ -266,7 +266,7 @@ index 2d46176e..47383486 100644
|
|||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
@@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Georgi Gerganov <ggerganov@gmail.com>
|
||||
Date: Thu, 19 Jun 2025 08:05:21 +0300
|
||||
Subject: [PATCH] metal : add mean kernel (#14267)
|
||||
|
||||
* metal : add mean kernel
|
||||
|
||||
ggml-ci
|
||||
|
||||
* cont : dedup implementation
|
||||
|
||||
ggml-ci
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++---
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------
|
||||
2 files changed, 67 insertions(+), 14 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||
index ee4f2dcb..f20f5615 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
+ GGML_METAL_KERNEL_TYPE_MEAN,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_LOG:
|
||||
return false; // TODO: implement
|
||||
case GGML_OP_SUM_ROWS:
|
||||
+ case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
+ case GGML_OP_MEAN:
|
||||
{
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
|
||||
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
+ id<MTLComputePipelineState> pipeline = nil;
|
||||
+
|
||||
+ switch (dst->op) {
|
||||
+ case GGML_OP_SUM_ROWS:
|
||||
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
+ break;
|
||||
+ case GGML_OP_MEAN:
|
||||
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
||||
+ break;
|
||||
+ default:
|
||||
+ GGML_ABORT("fatal error");
|
||||
+ }
|
||||
+
|
||||
+ int nth = 32; // SIMD width
|
||||
+
|
||||
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
+ nth *= 2;
|
||||
+ }
|
||||
|
||||
+ nth = MIN(nth, ne00);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index 9cfddf45..08e8d807 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -956,31 +956,61 @@ kernel void kernel_neg(
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
+template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
+ constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
- constant ggml_metal_kargs_sum_rows & args,
|
||||
- uint3 tpig[[thread_position_in_grid]]) {
|
||||
- int64_t i3 = tpig.z;
|
||||
- int64_t i2 = tpig.y;
|
||||
- int64_t i1 = tpig.x;
|
||||
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
+ ushort tiisg[[thread_index_in_simdgroup]],
|
||||
+ ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
+ int64_t i3 = tgpig.z;
|
||||
+ int64_t i2 = tgpig.y;
|
||||
+ int64_t i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
+ if (sgitg == 0) {
|
||||
+ shmem_f32[tiisg] = 0.0f;
|
||||
+ }
|
||||
+
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
- float row_sum = 0;
|
||||
+ float sumf = 0;
|
||||
|
||||
- for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
- row_sum += src_row[i0];
|
||||
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
+ sumf += src_row[i0];
|
||||
}
|
||||
|
||||
- dst_row[0] = row_sum;
|
||||
+ sumf = simd_sum(sumf);
|
||||
+
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ if (tiisg == 0) {
|
||||
+ shmem_f32[sgitg] = sumf;
|
||||
+ }
|
||||
+
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ sumf = shmem_f32[tiisg];
|
||||
+ sumf = simd_sum(sumf);
|
||||
+
|
||||
+ if (tpitg.x == 0) {
|
||||
+ dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
+ }
|
||||
}
|
||||
|
||||
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
+
|
||||
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
+
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
File diff suppressed because it is too large
Load Diff
|
@ -253,6 +253,7 @@ type Tensor interface {
|
|||
|
||||
Neg(ctx Context) Tensor
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Sub(ctx Context, t2 Tensor) Tensor
|
||||
Mul(ctx Context, t2 Tensor) Tensor
|
||||
Div(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
|
@ -276,6 +277,7 @@ type Tensor interface {
|
|||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context) Tensor
|
||||
SILU(ctx Context) Tensor
|
||||
RELU(ctx Context) Tensor
|
||||
Sigmoid(ctx Context) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
|
@ -297,6 +299,12 @@ type Tensor interface {
|
|||
|
||||
TopK(ctx Context, k int) Tensor
|
||||
Argsort(ctx Context) Tensor
|
||||
Mean(ctx Context) Tensor
|
||||
Variance(ctx Context) Tensor
|
||||
Stddev(ctx Context) Tensor
|
||||
Sqr(ctx Context) Tensor
|
||||
Sqrt(ctx Context) Tensor
|
||||
Clamp(ctx Context, min, max float32) Tensor
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention implements a fused attention
|
||||
|
|
|
@ -297,7 +297,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
|
||||
createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
|
||||
}
|
||||
case contains(t.Name, "cls", "output", "output_norm"):
|
||||
case contains(t.Name, "cls", "output", "output_norm",
|
||||
"altup_proj", "altup_unembd_proj",
|
||||
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
|
||||
createTensor(tensor{source: t}, output.bts, blocks)
|
||||
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
|
||||
// TODO: assign vision tensors to the gpu if possible
|
||||
|
@ -893,6 +895,13 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Sub(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
|
||||
if dim < 0 || dim >= C.GGML_MAX_DIMS {
|
||||
panic("invalid dimension")
|
||||
|
@ -1200,6 +1209,13 @@ func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
@ -1275,3 +1291,42 @@ func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
|
|||
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Mean(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_mean(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Variance(ctx ml.Context) ml.Tensor {
|
||||
return t.Add(ctx, t.Mean(ctx).Scale(ctx, -1)).
|
||||
Sqr(ctx).
|
||||
SumRows(ctx).
|
||||
Scale(ctx, 1/float64(t.Dim(0)))
|
||||
}
|
||||
|
||||
func (t *Tensor) Stddev(ctx ml.Context) ml.Tensor {
|
||||
return t.Variance(ctx).Sqrt(ctx)
|
||||
}
|
||||
|
||||
func (t *Tensor) Sqr(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sqr(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_sqrt(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
|||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true).
//
// Each block handles the row selected by blockIdx.x: its threads accumulate
// strided partial sums over the ncols elements of that row, the partials are
// combined with warp_reduce_sum, and thread 0 writes the result to dst[row].
// NOTE(review): the callers visible here launch WARP_SIZE threads per block;
// the single warp_reduce_sum presumably requires that a block does not exceed
// one warp - confirm before launching with a larger block.
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    // Compute the row base with a 64-bit offset: row * ncols evaluated in int
    // overflows once the tensor holds more than 2^31 elements.
    const float * x_row = x + (int64_t) row * ncols;

    float sum = 0.0f;
    for (int i = col; i < ncols; i += blockDim.x) {
        sum += x_row[i];
    }

    sum = warp_reduce_sum(sum);

    // only the first thread of the block writes the reduced value
    if (col != 0) {
        return;
    }

    dst[row] = norm ? sum / ncols : sum;
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include "ggml-cuda/ssm-scan.cuh"
|
||||
#include "ggml-cuda/sum.cuh"
|
||||
#include "ggml-cuda/sumrows.cuh"
|
||||
#include "ggml-cuda/mean.cuh"
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
|
@ -2322,6 +2323,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_SUM_ROWS:
|
||||
ggml_cuda_op_sum_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
ggml_cuda_op_mean(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_cuda_op_ssm_conv(ctx, dst);
|
||||
break;
|
||||
|
@ -3211,6 +3215,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
#include "mean.cuh"
|
||||
|
||||
// Computes the per-row mean of dst->src[0] into dst on the context's stream.
// Both tensors must be F32 and the source must be contiguous. One block of
// WARP_SIZE threads is launched per source row.
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const float * src0_d = (const float *) src0->data;
    float       * dst_d  = (float *) dst->data;

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    const dim3 block_nums(nrows, 1, 1);
    const dim3 block_dims(WARP_SIZE, 1, 1);

    cudaStream_t stream = ctx.stream();

    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
|
|
@ -0,0 +1,3 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
@ -1,25 +1,9 @@
|
|||
#include "sumrows.cuh"
|
||||
|
||||
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i = col; i < ncols; i += blockDim.x) {
|
||||
sum += x[row * ncols + i];
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (col == 0) {
|
||||
dst[row] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
|
||||
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
@ -3434,31 +3434,61 @@ kernel void kernel_neg(
|
|||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
int64_t i3 = tpig.z;
|
||||
int64_t i2 = tpig.y;
|
||||
int64_t i1 = tpig.x;
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
float sumf = 0;
|
||||
|
||||
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum;
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
|
|
|
@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_MEAN,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||
|
@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
|
@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||
case GGML_OP_LOG:
|
||||
return false; // TODO: implement
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
|
@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
|
|||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, ne00);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
|
@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
|
|||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
|
|
|
@ -956,31 +956,61 @@ kernel void kernel_neg(
|
|||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
int64_t i3 = tpig.z;
|
||||
int64_t i2 = tpig.y;
|
||||
int64_t i1 = tpig.x;
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
float sumf = 0;
|
||||
|
||||
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum;
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
package gemma3n
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
|
||||
*TextModel
|
||||
}
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
return m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
// TODO: setup hybrid (local sliding window + global) cache
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||
)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma3n", New)
|
||||
}
|
|
@ -0,0 +1,360 @@
|
|||
package gemma3n
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *TextScaledWordEmbedding `gguf:"token_embd"`
|
||||
|
||||
*PerLayerProjector
|
||||
|
||||
AltupEmbd *nn.Linear `gguf:"altup_proj"`
|
||||
AltupUnembd *nn.Linear `gguf:"altup_unembd_proj"`
|
||||
|
||||
TextLayers []TextLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
TextOptions
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
// Create a tensor of a single float32 value of 1.0 to use for altup correction
|
||||
one := ctx.Input().FromFloatSlice([]float32{1.0}, 1)
|
||||
|
||||
inputs := m.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(m.hiddenSize)))
|
||||
inputsPerLayer := m.PerLayerProjector.Forward(ctx, batch, inputs, &m.TextOptions)
|
||||
|
||||
targetMagnitude := inputs.Sqr(ctx).Mean(ctx).Sqrt(ctx)
|
||||
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
|
||||
|
||||
hiddenState := inputs.Repeat(ctx, 2, m.altupInputs-1)
|
||||
altupProj := m.AltupEmbd.Forward(ctx, hiddenState)
|
||||
altupProj = altupProj.Mul(ctx, targetMagnitude.Div(ctx, altupProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
|
||||
|
||||
hiddenStates := inputs.Concat(ctx, altupProj, 2)
|
||||
|
||||
firstSharedKeyValue := m.hiddenLayers - m.sharedKeyValueLayers
|
||||
for i, layer := range m.TextLayers {
|
||||
if i < firstSharedKeyValue {
|
||||
cache.SetLayer(i)
|
||||
} else if m.isLocal(i) {
|
||||
cache.SetLayer(firstSharedKeyValue - 2)
|
||||
} else {
|
||||
cache.SetLayer(firstSharedKeyValue - 1)
|
||||
}
|
||||
|
||||
var layerType int
|
||||
ropeBase := m.ropeBase
|
||||
if m.isLocal(i) {
|
||||
layerType = 1
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
|
||||
|
||||
// inputPerLayer = inputsPerLayer[:, i, :]
|
||||
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
|
||||
}
|
||||
|
||||
// hiddenStates = hiddenStates[:, :, 0]
|
||||
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
|
||||
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
|
||||
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
|
||||
|
||||
// hiddenState = hiddenStates[:, :, 1:]
|
||||
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
|
||||
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
|
||||
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
|
||||
|
||||
hiddenStates = hiddenStates0.Concat(ctx, altupUnembdProj, 2)
|
||||
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase := m.ropeBase
|
||||
if m.isLocal(layer) {
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextScaledWordEmbedding struct {
|
||||
*nn.Embedding
|
||||
}
|
||||
|
||||
func (e TextScaledWordEmbedding) Forward(ctx ml.Context, inputIDs ml.Tensor, scale float64) ml.Tensor {
|
||||
return e.Embedding.Forward(ctx, inputIDs).Scale(ctx, scale)
|
||||
}
|
||||
|
||||
type PerLayerProjector struct {
|
||||
TokenEmbedding *TextScaledWordEmbedding `gguf:"per_layer_token_embd"`
|
||||
Projector *nn.Linear `gguf:"per_layer_model_proj"`
|
||||
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
|
||||
}
|
||||
|
||||
func (p PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
||||
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, batch.Inputs.Dim(0), batch.Inputs.Dim(1))
|
||||
|
||||
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, math.Sqrt(float64(opts.hiddenSize)))
|
||||
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
||||
|
||||
if inputsPerLayer != nil {
|
||||
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
||||
}
|
||||
|
||||
return perLayerProjection
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
*AltUp
|
||||
*Laurel
|
||||
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
Attention *TextAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
||||
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *TextMLP
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||
|
||||
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
|
||||
PerLayerProjection *nn.Linear `gguf:"proj"`
|
||||
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
|
||||
}
|
||||
|
||||
func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, positions, one ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, activationSparsityScale float64, opts *TextOptions) ml.Tensor {
|
||||
predictions := d.Predict(ctx, hiddenStates, opts)
|
||||
active := opts.altupActive(ctx, predictions)
|
||||
|
||||
attn := d.AttentionNorm.Forward(ctx, active, opts.eps)
|
||||
laurel := d.Laurel.Forward(ctx, attn, opts)
|
||||
|
||||
attn = d.Attention.Forward(ctx, attn, positions, cache, sharedKV, ropeBase, opts)
|
||||
attn = d.PostAttentionNorm.Forward(ctx, attn, opts.eps)
|
||||
attn = active.Add(ctx, attn)
|
||||
attn = attn.Add(ctx, laurel).Scale(ctx, 1/math.Sqrt(2))
|
||||
|
||||
mlp := d.MLPNorm.Forward(ctx, attn, opts.eps)
|
||||
mlp = d.MLP.Forward(ctx, mlp, activationSparsityScale)
|
||||
mlp = d.PostMLPNorm.Forward(ctx, mlp, opts.eps)
|
||||
mlp = attn.Add(ctx, mlp)
|
||||
|
||||
predictions = d.Correct(ctx, predictions, mlp, one, opts)
|
||||
active = opts.altupActive(ctx, predictions)
|
||||
if opts.altupCorrectScale {
|
||||
active = d.ScaleCorrectedOutput(ctx, active)
|
||||
}
|
||||
|
||||
active = d.PerLayerInputGate.Forward(ctx, active)
|
||||
active = active.GELU(ctx)
|
||||
active = active.Mul(ctx, perLayerInput)
|
||||
|
||||
active = d.PerLayerProjection.Forward(ctx, active)
|
||||
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
|
||||
|
||||
// inactive := predictions[:, :, 1:]
|
||||
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
|
||||
active = inactive.Add(ctx, active)
|
||||
|
||||
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
|
||||
return predictions0.Concat(ctx, active, 2)
|
||||
}
|
||||
|
||||
type AltUp struct {
|
||||
CorrectionScale ml.Tensor `gguf:"altup_correct_scale.weight"`
|
||||
PredictionCoefficient *nn.Linear `gguf:"altup_predict_coef"`
|
||||
CorrectionCoefficient *nn.Linear `gguf:"altup_correct_coef"`
|
||||
Router *nn.Linear `gguf:"altup_router"`
|
||||
RouterNorm *nn.RMSNorm `gguf:"altup_router_norm"`
|
||||
}
|
||||
|
||||
func (a AltUp) computeRouterModalities(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
routerInputs := a.RouterNorm.Forward(ctx, hiddenStates, opts.eps).Scale(ctx, 1.0/float64(opts.hiddenSize))
|
||||
return a.Router.Forward(ctx, routerInputs).Tanh(ctx)
|
||||
}
|
||||
|
||||
func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
modalities := a.computeRouterModalities(ctx, opts.altupActive(ctx, hiddenStates), opts)
|
||||
|
||||
coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
|
||||
coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
|
||||
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
predictions := coefficients.Mulmat(ctx, hiddenStates)
|
||||
predictions = predictions.Add(ctx, hiddenStates)
|
||||
return predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
|
||||
func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
innovation := activated.Sub(ctx, opts.altupActive(ctx, predictions))
|
||||
innovation = innovation.Repeat(ctx, 2, opts.altupInputs)
|
||||
|
||||
modalities := a.computeRouterModalities(ctx, activated, opts)
|
||||
coefficients := a.CorrectionCoefficient.Forward(ctx, modalities)
|
||||
coefficients = coefficients.Add(ctx, one)
|
||||
|
||||
coefficients = coefficients.Reshape(ctx, 1, coefficients.Dim(0), coefficients.Dim(1))
|
||||
coefficients = coefficients.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
corrected := innovation.Mul(ctx, coefficients)
|
||||
corrected = corrected.Add(ctx, predictions)
|
||||
return corrected
|
||||
}
|
||||
|
||||
func (a AltUp) ScaleCorrectedOutput(ctx ml.Context, predictions ml.Tensor) ml.Tensor {
|
||||
return predictions.Mul(ctx, a.CorrectionScale)
|
||||
}
|
||||
|
||||
type Laurel struct {
|
||||
LinearLeft *nn.Linear `gguf:"laurel_l"`
|
||||
LinearRight *nn.Linear `gguf:"laurel_r"`
|
||||
PostLaurelNorm *nn.RMSNorm `gguf:"laurel_post_norm"`
|
||||
}
|
||||
|
||||
func (l Laurel) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = l.LinearLeft.Forward(ctx, hiddenStates)
|
||||
hiddenStates = l.LinearRight.Forward(ctx, hiddenStates)
|
||||
hiddenStates = l.PostLaurelNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type TextAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
query := attn.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
var key, value ml.Tensor
|
||||
if !sharedKV {
|
||||
key = attn.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
value = attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
value = value.RMSNorm(ctx, nil, opts.eps)
|
||||
}
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1., cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSparsityScale float64) ml.Tensor {
|
||||
upStates := mlp.Up.Forward(ctx, hiddenStates)
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates)
|
||||
if activationSparsityScale > 0 {
|
||||
mean := hiddenStates.Mean(ctx)
|
||||
std := hiddenStates.Stddev(ctx).Scale(ctx, activationSparsityScale)
|
||||
cutoff := mean.Add(ctx, std)
|
||||
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
|
||||
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
// TextOptions carries the hyperparameters of the text model. newTextModel
// fills it from the model configuration.
type TextOptions struct {
	// Layer count and sizes.
	hiddenLayers            int
	hiddenSize              int
	hiddenSizePerLayerInput int

	// Attention geometry; headDim derives the per-head size from these.
	numHeads    int
	numKVHeads  int
	keyLength   int
	valueLength int

	// Number of trailing layers that reuse an earlier layer's cache entry.
	sharedKeyValueLayers int

	// altup configuration.
	altupActiveIndex  int
	altupInputs       int
	altupCorrectScale bool

	// Normalization epsilon and RoPE parameters.
	eps           float32
	ropeBase      float32
	ropeBaseLocal float32
	ropeScale     float32

	// Per-layer settings, indexed by layer.
	slidingWindowPattern    []bool
	activationSparsityScale []float32
}
|
||||
|
||||
func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
// t[:, :, o.altupActiveIndex]
|
||||
return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
|
||||
}
|
||||
|
||||
func (o *TextOptions) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
func (o *TextOptions) isLocal(i int) bool {
|
||||
return o.slidingWindowPattern[i]
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
return &TextModel{
|
||||
TextLayers: make([]TextLayer, c.Uint("block_count")),
|
||||
TextOptions: TextOptions{
|
||||
hiddenLayers: int(c.Uint("block_count")),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
sharedKeyValueLayers: int(c.Uint("attention.shared_kv_layers")),
|
||||
|
||||
altupActiveIndex: int(c.Uint("altup.active_idx")),
|
||||
altupInputs: int(c.Uint("altup.num_inputs")),
|
||||
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeBase: c.Float("rope.freq_base", 1_000_000),
|
||||
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
|
||||
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||
activationSparsityScale: c.Floats("activation_sparsity_scale"),
|
||||
},
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ package models
|
|||
import (
|
||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/llama"
|
||||
_ "github.com/ollama/ollama/model/models/llama4"
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
|
|
Loading…
Reference in New Issue