mirror of https://github.com/ollama/ollama.git
170 lines
6.9 KiB
Diff
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
From: Georgi Gerganov <ggerganov@gmail.com>
|
|
Date: Thu, 19 Jun 2025 08:05:21 +0300
|
|
Subject: [PATCH] metal : add mean kernel (#14267)
|
|
|
|
* metal : add mean kernel
|
|
|
|
ggml-ci
|
|
|
|
* cont : dedup implementation
|
|
|
|
ggml-ci
|
|
---
|
|
ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++---
|
|
ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------
|
|
2 files changed, 67 insertions(+), 14 deletions(-)
|
|
|
|
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
|
index a9eeebc6..110c9ece 100644
|
|
--- a/ggml/src/ggml-metal/ggml-metal.m
|
|
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
|
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
|
|
GGML_METAL_KERNEL_TYPE_COS,
|
|
GGML_METAL_KERNEL_TYPE_NEG,
|
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
+ GGML_METAL_KERNEL_TYPE_MEAN,
|
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
|
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
|
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
case GGML_OP_LOG:
|
|
return false; // TODO: implement
|
|
case GGML_OP_SUM_ROWS:
|
|
+ case GGML_OP_MEAN:
|
|
case GGML_OP_SOFT_MAX:
|
|
case GGML_OP_GROUP_NORM:
|
|
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
|
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
} break;
|
|
case GGML_OP_SUM_ROWS:
|
|
+ case GGML_OP_MEAN:
|
|
{
|
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
|
|
|
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
|
+ id<MTLComputePipelineState> pipeline = nil;
|
|
+
|
|
+ switch (dst->op) {
|
|
+ case GGML_OP_SUM_ROWS:
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
|
+ break;
|
|
+ case GGML_OP_MEAN:
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
|
+ break;
|
|
+ default:
|
|
+ GGML_ABORT("fatal error");
|
|
+ }
|
|
+
|
|
+ int nth = 32; // SIMD width
|
|
+
|
|
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
+ nth *= 2;
|
|
+ }
|
|
|
|
+ nth = MIN(nth, ne00);
|
|
|
|
ggml_metal_kargs_sum_rows args = {
|
|
/*.ne00 =*/ ne00,
|
|
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
|
|
};
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
|
|
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
|
|
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
} break;
|
|
case GGML_OP_SOFT_MAX:
|
|
{
|
|
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
|
index 9cfddf45..08e8d807 100644
|
|
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
|
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
|
@@ -956,31 +956,61 @@ kernel void kernel_neg(
|
|
dst[tpig] = -src0[tpig];
|
|
}
|
|
|
|
+template <bool norm>
|
|
kernel void kernel_sum_rows(
|
|
+ constant ggml_metal_kargs_sum_rows & args,
|
|
device const float * src0,
|
|
device float * dst,
|
|
- constant ggml_metal_kargs_sum_rows & args,
|
|
- uint3 tpig[[thread_position_in_grid]]) {
|
|
- int64_t i3 = tpig.z;
|
|
- int64_t i2 = tpig.y;
|
|
- int64_t i1 = tpig.x;
|
|
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
+ ushort tiisg[[thread_index_in_simdgroup]],
|
|
+ ushort3 ntg[[threads_per_threadgroup]]) {
|
|
+ int64_t i3 = tgpig.z;
|
|
+ int64_t i2 = tgpig.y;
|
|
+ int64_t i1 = tgpig.x;
|
|
|
|
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
|
return;
|
|
}
|
|
|
|
+ if (sgitg == 0) {
|
|
+ shmem_f32[tiisg] = 0.0f;
|
|
+ }
|
|
+
|
|
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
|
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
|
|
|
- float row_sum = 0;
|
|
+ float sumf = 0;
|
|
|
|
- for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
|
- row_sum += src_row[i0];
|
|
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
|
+ sumf += src_row[i0];
|
|
}
|
|
|
|
- dst_row[0] = row_sum;
|
|
+ sumf = simd_sum(sumf);
|
|
+
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
+
|
|
+ if (tiisg == 0) {
|
|
+ shmem_f32[sgitg] = sumf;
|
|
+ }
|
|
+
|
|
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
+
|
|
+ sumf = shmem_f32[tiisg];
|
|
+ sumf = simd_sum(sumf);
|
|
+
|
|
+ if (tpitg.x == 0) {
|
|
+ dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
|
+ }
|
|
}
|
|
|
|
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
|
+
|
|
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
|
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
|
+
|
|
template<typename T>
|
|
kernel void kernel_soft_max(
|
|
device const char * src0,
|