mirror of https://github.com/ollama/ollama.git
1294 lines
60 KiB
Diff
1294 lines
60 KiB
Diff
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||
|
From: Daniel Hiltgen <daniel@ollama.com>
|
||
|
Date: Mon, 21 Jul 2025 12:06:13 -0700
|
||
|
Subject: [PATCH] MXFP4
|
||
|
|
||
|
Partial implementation of MXFP4 tensor type
|
||
|
---
|
||
|
ggml/include/ggml.h | 2 +-
|
||
|
ggml/src/ggml-common.h | 7 +
|
||
|
ggml/src/ggml-cpu/ggml-cpu-quants.h | 2 +
|
||
|
ggml/src/ggml-cpu/ggml-cpu.c | 5 +
|
||
|
ggml/src/ggml-cpu/ops.cpp | 1 +
|
||
|
ggml/src/ggml-cpu/vec.cpp | 90 ++++++++
|
||
|
ggml/src/ggml-cpu/vec.h | 2 +
|
||
|
ggml/src/ggml-cuda/convert.cu | 80 +++++++
|
||
|
ggml/src/ggml-cuda/ggml-cuda.cu | 16 +-
|
||
|
ggml/src/ggml-cuda/mmvmxfp4.cu | 307 ++++++++++++++++++++++++++
|
||
|
ggml/src/ggml-cuda/mmvmxfp4.cuh | 9 +
|
||
|
ggml/src/ggml-metal/ggml-metal-impl.h | 3 +
|
||
|
ggml/src/ggml-metal/ggml-metal.m | 25 ++-
|
||
|
ggml/src/ggml-metal/ggml-metal.metal | 173 ++++++++++++++-
|
||
|
ggml/src/ggml-quants.c | 142 +++++++++++-
|
||
|
ggml/src/ggml-quants.h | 6 +
|
||
|
ggml/src/ggml.c | 13 +-
|
||
|
17 files changed, 868 insertions(+), 15 deletions(-)
|
||
|
create mode 100644 ggml/src/ggml-cuda/mmvmxfp4.cu
|
||
|
create mode 100644 ggml/src/ggml-cuda/mmvmxfp4.cuh
|
||
|
|
||
|
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
|
||
|
index e91dedf1..873baa24 100644
|
||
|
--- a/ggml/include/ggml.h
|
||
|
+++ b/ggml/include/ggml.h
|
||
|
@@ -353,7 +353,7 @@ extern "C" {
|
||
|
GGML_TYPE_F16 = 1,
|
||
|
GGML_TYPE_Q4_0 = 2,
|
||
|
GGML_TYPE_Q4_1 = 3,
|
||
|
- // GGML_TYPE_Q4_2 = 4, support has been removed
|
||
|
+ GGML_TYPE_MXFP4 = 4, // Formerly removed type GGML_TYPE_Q4_2
|
||
|
// GGML_TYPE_Q4_3 = 5, support has been removed
|
||
|
GGML_TYPE_Q5_0 = 6,
|
||
|
GGML_TYPE_Q5_1 = 7,
|
||
|
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
|
||
|
index 086c822d..e0d71451 100644
|
||
|
--- a/ggml/src/ggml-common.h
|
||
|
+++ b/ggml/src/ggml-common.h
|
||
|
@@ -417,6 +417,13 @@ typedef struct {
|
||
|
} block_iq4_xs;
|
||
|
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
|
||
|
|
||
|
+#define MXFP4 32
|
||
|
+typedef struct {
|
||
|
+ uint8_t d; // scale E8M0 float
|
||
|
+ uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
|
||
|
+} block_mxfp4;
|
||
|
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");
|
||
|
+
|
||
|
#endif // GGML_COMMON_DECL
|
||
|
#endif // GGML_COMMON_DECL
|
||
|
|
||
|
diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.h b/ggml/src/ggml-cpu/ggml-cpu-quants.h
|
||
|
index e33d9d47..6a25d062 100644
|
||
|
--- a/ggml/src/ggml-cpu/ggml-cpu-quants.h
|
||
|
+++ b/ggml/src/ggml-cpu/ggml-cpu-quants.h
|
||
|
@@ -58,6 +58,8 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const
|
||
|
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||
|
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||
|
|
||
|
+void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
|
||
|
+
|
||
|
#ifdef __cplusplus
|
||
|
}
|
||
|
#endif
|
||
|
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
|
||
|
index 2462d2b8..bff9c426 100644
|
||
|
--- a/ggml/src/ggml-cpu/ggml-cpu.c
|
||
|
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
|
||
|
@@ -362,6 +362,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
||
|
.vec_dot_type = GGML_TYPE_Q8_K,
|
||
|
.nrows = 1,
|
||
|
},
|
||
|
+ [GGML_TYPE_MXFP4] = {
|
||
|
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_mxfp4,
|
||
|
+ .vec_dot_type = GGML_TYPE_F32,
|
||
|
+ .nrows = 1,
|
||
|
+ },
|
||
|
};
|
||
|
|
||
|
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
|
||
|
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||
|
index 654e2f28..be0aa683 100644
|
||
|
--- a/ggml/src/ggml-cpu/ops.cpp
|
||
|
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||
|
@@ -4965,6 +4965,7 @@ void ggml_compute_forward_clamp(
|
||
|
case GGML_TYPE_I32:
|
||
|
case GGML_TYPE_I64:
|
||
|
case GGML_TYPE_F64:
|
||
|
+ case GGML_TYPE_MXFP4:
|
||
|
case GGML_TYPE_COUNT:
|
||
|
{
|
||
|
GGML_ABORT("fatal error");
|
||
|
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
|
||
|
index 02d40618..ec3ec9b1 100644
|
||
|
--- a/ggml/src/ggml-cpu/vec.cpp
|
||
|
+++ b/ggml/src/ggml-cpu/vec.cpp
|
||
|
@@ -250,3 +250,93 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl
|
||
|
}
|
||
|
return sum = (ggml_float)logf(sum);
|
||
|
}
|
||
|
+
|
||
|
+#define MXFP4 32
|
||
|
+typedef struct {
|
||
|
+ uint8_t d; // scale E8M0 float
|
||
|
+ uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
|
||
|
+} block_mxfp4;
|
||
|
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");
|
||
|
+#define MXFP4_VALS {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0}
|
||
|
+
|
||
|
+void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
|
||
|
+ assert(nrc == 1);
|
||
|
+ GGML_UNUSED(nrc);
|
||
|
+ GGML_UNUSED(bx);
|
||
|
+ GGML_UNUSED(by);
|
||
|
+ GGML_UNUSED(bs);
|
||
|
+ ggml_float mxfp4_table[] = MXFP4_VALS;
|
||
|
+
|
||
|
+#if defined(GGML_SIMD)
|
||
|
+ float sumf = 0.0f;
|
||
|
+ const int np = (n & ~(GGML_F32_STEP - 1));
|
||
|
+ const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx;
|
||
|
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||
|
+
|
||
|
+ GGML_F32_VEC scalev;
|
||
|
+ GGML_F32_VEC ax[GGML_F32_ARR];
|
||
|
+ GGML_F32_VEC ay[GGML_F32_ARR];
|
||
|
+ for (int i = 0; i < np; i += GGML_F32_STEP) { // ARM: +16 AVX512: +64
|
||
|
+ for (int j = 0; j < GGML_F32_ARR; j++) { // ARM: 0 .. 4 AVX512: 0 .. 4
|
||
|
+ // convert GGML_F32_ARR X elements
|
||
|
+ const int ib = (i + j*GGML_F32_EPR) / MXFP4;
|
||
|
+ const block_mxfp4 * GGML_RESTRICT x = &xx[ib];
|
||
|
+ union {
|
||
|
+ uint32_t as_bits;
|
||
|
+ float as_value;
|
||
|
+ } scale;
|
||
|
+ scale.as_bits = (((uint32_t)x->d) << 23);
|
||
|
+ scalev = GGML_F32_VEC_SET1(scale.as_value);
|
||
|
+ float xf[GGML_F32_EPR]= {0.f};
|
||
|
+ assert(((i+j*GGML_F32_EPR) % MXFP4)+GGML_F32_ARR < MXFP4 && "block overrun");
|
||
|
+ for (int qi = 0; qi < GGML_F32_EPR/2 ; ++qi) {
|
||
|
+ xf[qi*2] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf)];
|
||
|
+ xf[qi*2+1] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf0) >> 4];
|
||
|
+ }
|
||
|
+
|
||
|
+ ax[j] = GGML_F32_VEC_MUL(GGML_F32_VEC_LOAD(xf), scalev);
|
||
|
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
|
||
|
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
|
||
|
+ }
|
||
|
+ }
|
||
|
+ GGML_F32_VEC_REDUCE(sumf, sum);
|
||
|
+
|
||
|
+ // leftovers
|
||
|
+ for (int i = np; i < n; i+=2) {
|
||
|
+ const int ib = i / MXFP4;
|
||
|
+ const block_mxfp4 * GGML_RESTRICT x = &xx[ib];
|
||
|
+ union {
|
||
|
+ uint32_t as_bits;
|
||
|
+ float as_value;
|
||
|
+ } scale;
|
||
|
+ scale.as_bits = (((uint32_t)x->d) << 23);
|
||
|
+ sumf += y[i] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf)];
|
||
|
+ sumf += y[i+1] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf0) >> 4];
|
||
|
+ }
|
||
|
+
|
||
|
+
|
||
|
+#else // defined(GGML_SIMD)
|
||
|
+ const int nb = n / MXFP4;
|
||
|
+ assert(n % MXFP4 == 0);
|
||
|
+
|
||
|
+ int yi = 0;
|
||
|
+
|
||
|
+ const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx;
|
||
|
+
|
||
|
+ ggml_float sumf = 0.0;
|
||
|
+ for (int ib = 0; ib < nb; ++ib) {
|
||
|
+ const block_mxfp4 * GGML_RESTRICT x = &xx[ib + 0];
|
||
|
+ union {
|
||
|
+ uint32_t as_bits;
|
||
|
+ float as_value;
|
||
|
+ } scale;
|
||
|
+ scale.as_bits = (((uint32_t)x->d) << 23);
|
||
|
+ for (int i = 0; i < MXFP4/2; ++i) {
|
||
|
+ sumf += mxfp4_table[(x->qs[i] & 0xf)] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2]);
|
||
|
+ sumf += mxfp4_table[(x->qs[i] & 0xf0) >> 4] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2+1]);
|
||
|
+ }
|
||
|
+ }
|
||
|
+#endif
|
||
|
+
|
||
|
+ *s = sumf;
|
||
|
+}
|
||
|
diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h
|
||
|
index 23cbb305..7480ca08 100644
|
||
|
--- a/ggml/src/ggml-cpu/vec.h
|
||
|
+++ b/ggml/src/ggml-cpu/vec.h
|
||
|
@@ -42,6 +42,8 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
|
||
|
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
|
||
|
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
|
||
|
|
||
|
+void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
|
||
|
+
|
||
|
void ggml_vec_silu_f32(const int n, float * y, const float * x);
|
||
|
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
|
||
|
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
|
||
|
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
|
||
|
index c6dec427..0e016ccc 100644
|
||
|
--- a/ggml/src/ggml-cuda/convert.cu
|
||
|
+++ b/ggml/src/ggml-cuda/convert.cu
|
||
|
@@ -571,6 +571,82 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
|
||
|
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
|
||
|
}
|
||
|
|
||
|
+// MXFP4 dequantize derived from dequantize_block_q4_0
|
||
|
+template<typename dst_t>
|
||
|
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||
|
+ const uint16_t dst_bias = 15;
|
||
|
+ const uint16_t dst_0p5 = 0x3800;
|
||
|
+ const uint16_t dst_m_bits = 10;
|
||
|
+ const int64_t i = blockIdx.x;
|
||
|
+
|
||
|
+ // assume 32 threads
|
||
|
+ const int64_t tid = threadIdx.x;
|
||
|
+ const int64_t il = tid/8;
|
||
|
+ const int64_t ir = tid%8;
|
||
|
+ const int64_t ib = 8*i + ir;
|
||
|
+ if (ib >= nb32) {
|
||
|
+ return;
|
||
|
+ }
|
||
|
+
|
||
|
+ const uint64_t offset = 256*i + MXFP4*ir + 8*il;
|
||
|
+ dst_t * y = yy + offset;
|
||
|
+
|
||
|
+ const block_mxfp4 * x = (const block_mxfp4 *)vx + ib;
|
||
|
+ union {
|
||
|
+ uint32_t as_bits;
|
||
|
+ float as_value;
|
||
|
+ } scale;
|
||
|
+ scale.as_bits = (((uint32_t)x->d) << 23);
|
||
|
+
|
||
|
+ // offset within the block 1/4 chunks (8 items)
|
||
|
+ const uint8_t * q = x->qs + 4*il;
|
||
|
+
|
||
|
+ for (int l = 0; l < 4; ++l) {
|
||
|
+ uint16_t em0 = q[l] & 0x07;
|
||
|
+ uint16_t em1 = q[l] & 0x70;
|
||
|
+ // float16 values
|
||
|
+ iq1m_scale_t x0;
|
||
|
+ iq1m_scale_t x1;
|
||
|
+
|
||
|
+ x0.u16 = (em0 << (dst_m_bits - 1)) | ((q[l] & 0x08) << 12);
|
||
|
+ x1.u16 = (em1 << (dst_m_bits - 5)) | ((q[l] & 0x80) << 8);
|
||
|
+
|
||
|
+ // Three cases:
|
||
|
+ // x is normal and non-zero: Correct bias
|
||
|
+ if ((em0 & 0x06) != 0) {
|
||
|
+ x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ if ((em1 & 0x60) != 0) {
|
||
|
+ x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
||
|
+ if (em0 == 0x01) {
|
||
|
+ x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
|
||
|
+ }
|
||
|
+ if (em1 == 0x10) {
|
||
|
+ x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
|
||
|
+ }
|
||
|
+ // x is zero, do nothing
|
||
|
+
|
||
|
+ // XXX it looks correct here - but mulmat still gives bad results...
|
||
|
+ // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n",
|
||
|
+ // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 0, scale * float(x0.f16));
|
||
|
+ // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n",
|
||
|
+ // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 1, scale * float(x1.f16));
|
||
|
+
|
||
|
+ y[l*2] = scale.as_value * float(x0.f16);
|
||
|
+ y[l*2+1] = scale.as_value * float(x1.f16);
|
||
|
+ }
|
||
|
+}
|
||
|
+
|
||
|
+// derived from dequantize_row_q4_0_cuda
|
||
|
+template<typename dst_t>
|
||
|
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
|
||
|
+ const int nb32 = k / 32;
|
||
|
+ const int nb = (k + 255) / 256;
|
||
|
+ dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y, nb32);
|
||
|
+}
|
||
|
+
|
||
|
template <typename src_t, typename dst_t>
|
||
|
static __global__ void convert_unary(
|
||
|
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||
|
@@ -664,6 +740,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||
|
return convert_unary_cont_cuda<float>;
|
||
|
case GGML_TYPE_BF16:
|
||
|
return convert_unary_cont_cuda<nv_bfloat16>;
|
||
|
+ case GGML_TYPE_MXFP4:
|
||
|
+ return dequantize_row_mxfp4_cuda;
|
||
|
default:
|
||
|
return nullptr;
|
||
|
}
|
||
|
@@ -713,6 +791,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||
|
return convert_unary_cont_cuda<half>;
|
||
|
case GGML_TYPE_BF16:
|
||
|
return convert_unary_cont_cuda<nv_bfloat16>;
|
||
|
+ case GGML_TYPE_MXFP4:
|
||
|
+ return dequantize_row_mxfp4_cuda;
|
||
|
default:
|
||
|
return nullptr;
|
||
|
}
|
||
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||
|
index 28ccf4be..bb19b06e 100644
|
||
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||
|
@@ -21,6 +21,7 @@
|
||
|
#include "ggml-cuda/im2col.cuh"
|
||
|
#include "ggml-cuda/mmq.cuh"
|
||
|
#include "ggml-cuda/mmv.cuh"
|
||
|
+#include "ggml-cuda/mmvmxfp4.cuh"
|
||
|
#include "ggml-cuda/mmvq.cuh"
|
||
|
#include "ggml-cuda/norm.cuh"
|
||
|
#include "ggml-cuda/opt-step-adamw.cuh"
|
||
|
@@ -1202,7 +1203,7 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||
|
|
||
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||
|
|
||
|
- const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
|
||
|
+ const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT && src0->type != GGML_TYPE_MXFP4;
|
||
|
|
||
|
if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
|
||
|
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
|
||
|
@@ -1924,7 +1925,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||
|
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
|
||
|
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
|
||
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||
|
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||
|
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
|
||
|
+ && src0->type != GGML_TYPE_MXFP4;
|
||
|
+ bool use_mul_mat_vec_mxfp4 = src0->type == GGML_TYPE_MXFP4
|
||
|
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||
|
+ && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
|
||
|
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
|
||
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||
|
|
||
|
@@ -1978,6 +1983,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||
|
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
|
||
|
} else if (use_mul_mat_q) {
|
||
|
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
|
||
|
+ } else if (use_mul_mat_vec_mxfp4) {
|
||
|
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_mxfp4, nullptr);
|
||
|
} else {
|
||
|
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
|
||
|
}
|
||
|
@@ -1997,6 +2004,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||
|
|
||
|
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||
|
+ if (ne2 == 1 && src0->type == GGML_TYPE_MXFP4) {
|
||
|
+ ggml_cuda_mul_mat_vec_mxfp4(ctx, src0, src1, ids, dst);
|
||
|
+ return;
|
||
|
+ }
|
||
|
if (ne2 == 1) {
|
||
|
if (ggml_is_quantized(src0->type)) {
|
||
|
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
|
||
|
@@ -3056,6 +3067,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||
|
case GGML_TYPE_IQ4_NL:
|
||
|
case GGML_TYPE_IQ4_XS:
|
||
|
case GGML_TYPE_BF16:
|
||
|
+ case GGML_TYPE_MXFP4:
|
||
|
#ifdef GGML_USE_MUSA
|
||
|
if (a->type == GGML_TYPE_Q3_K) {
|
||
|
return false;
|
||
|
diff --git a/ggml/src/ggml-cuda/mmvmxfp4.cu b/ggml/src/ggml-cuda/mmvmxfp4.cu
|
||
|
new file mode 100644
|
||
|
index 00000000..da62062b
|
||
|
--- /dev/null
|
||
|
+++ b/ggml/src/ggml-cuda/mmvmxfp4.cu
|
||
|
@@ -0,0 +1,307 @@
|
||
|
+#include "ggml.h"
|
||
|
+#include "common.cuh"
|
||
|
+#include "mmvmxfp4.cuh"
|
||
|
+
|
||
|
+// MXFP4 implementation derived from mmv.cu float32 code paths
|
||
|
+typedef union {
|
||
|
+ half f16;
|
||
|
+ uint16_t u16;
|
||
|
+} f16_t;
|
||
|
+
|
||
|
+template <typename type_acc, int block_size> // TODO type_acc unused - consider bf16 support
|
||
|
+static __global__ void mul_mat_vec_mxfp4(
|
||
|
+ const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||
|
+ const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
|
||
|
+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||
|
+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
|
||
|
+ const int64_t row = blockIdx.x;
|
||
|
+ const int64_t channel_dst = blockIdx.y;
|
||
|
+ const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
||
|
+ const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
||
|
+ const int64_t sample_dst = blockIdx.z;
|
||
|
+ const int64_t sample_x = sample_dst / sample_ratio;
|
||
|
+ const int64_t sample_y = sample_dst;
|
||
|
+ const int tid = threadIdx.x;
|
||
|
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||
|
+
|
||
|
+ const uint16_t dst_bias = 15;
|
||
|
+ const uint16_t dst_0p5 = 0x3800;
|
||
|
+ const uint16_t dst_m_bits = 10;
|
||
|
+
|
||
|
+ x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||
|
+ y += sample_y *stride_sample_y + channel_y *stride_channel_y;
|
||
|
+ dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
|
||
|
+
|
||
|
+ const float2 * y2 = (const float2 *) y;
|
||
|
+
|
||
|
+ extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float)
|
||
|
+ float * buf_iw = (float *) data_mmv;
|
||
|
+
|
||
|
+ if (block_size > warp_size) {
|
||
|
+ if (tid < warp_size) {
|
||
|
+ buf_iw[tid] = 0.0f;
|
||
|
+ }
|
||
|
+ __syncthreads();
|
||
|
+ }
|
||
|
+
|
||
|
+ float sumf = 0.0f;
|
||
|
+
|
||
|
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||
|
+ int offset0 = col2 / (MXFP4/2);
|
||
|
+ int i = col2 % (MXFP4/2);
|
||
|
+ const block_mxfp4 *x2 = x+offset0;
|
||
|
+
|
||
|
+ union {
|
||
|
+ uint32_t as_bits;
|
||
|
+ float as_value;
|
||
|
+ } scale;
|
||
|
+ scale.as_bits = (((uint32_t)x2->d) << 23);
|
||
|
+ uint16_t em0 = x2->qs[i] & 0x07;
|
||
|
+ uint16_t em1 = x2->qs[i] & 0x70;
|
||
|
+ // float16 values
|
||
|
+ f16_t x0;
|
||
|
+ f16_t x1;
|
||
|
+ x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12);
|
||
|
+ x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8);
|
||
|
+
|
||
|
+ // Three cases:
|
||
|
+ // x is normal and non-zero: Correct bias
|
||
|
+ if ((em0 & 0x06) != 0) {
|
||
|
+ x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ if ((em1 & 0x60) != 0) {
|
||
|
+ x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
||
|
+ if (em0 == 0x01) {
|
||
|
+ x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
|
||
|
+ }
|
||
|
+ if (em1 == 0x10) {
|
||
|
+ x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
|
||
|
+ }
|
||
|
+ // x is zero, do nothing
|
||
|
+
|
||
|
+ if (isnan(scale.as_value)) {
|
||
|
+ sumf = scale.as_value;
|
||
|
+ break;
|
||
|
+ }
|
||
|
+
|
||
|
+ const float2 tmpx = {x0.f16, x1.f16};
|
||
|
+ const float2 tmpy = y2[col2];
|
||
|
+ sumf += tmpx.x*tmpy.x*scale.as_value;
|
||
|
+ sumf += tmpx.y*tmpy.y*scale.as_value;
|
||
|
+ }
|
||
|
+
|
||
|
+ sumf = warp_reduce_sum<warp_size>(sumf);
|
||
|
+
|
||
|
+ if (block_size > warp_size) {
|
||
|
+ buf_iw[tid/warp_size] = sumf;
|
||
|
+ __syncthreads();
|
||
|
+ if (tid >= warp_size) {
|
||
|
+ return;
|
||
|
+ }
|
||
|
+ sumf = buf_iw[tid];
|
||
|
+ sumf = warp_reduce_sum<warp_size>(sumf);
|
||
|
+ }
|
||
|
+
|
||
|
+ if (tid != 0) {
|
||
|
+ return;
|
||
|
+ }
|
||
|
+
|
||
|
+ dst[row] = sumf;
|
||
|
+}
|
||
|
+
|
||
|
+template <typename type_acc>
|
||
|
+static void launch_mul_mat_vec_cuda_mxfp4(
|
||
|
+ const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst,
|
||
|
+ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||
|
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||
|
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||
|
+ cudaStream_t stream) {
|
||
|
+ GGML_ASSERT(ncols % 2 == 0);
|
||
|
+ // GGML_ASSERT(stride_row % 2 == 0); // TODO
|
||
|
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
||
|
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
||
|
+ const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
||
|
+ const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
||
|
+ int device;
|
||
|
+ int warp_size;
|
||
|
+
|
||
|
+ CUDA_CHECK(cudaGetDevice(&device));
|
||
|
+ warp_size = ggml_cuda_info().devices[device].warp_size;
|
||
|
+
|
||
|
+ int64_t block_size_best = warp_size;
|
||
|
+ int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
|
||
|
+ int64_t max_block_size = 256;
|
||
|
+ if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
|
||
|
+ max_block_size = 128;
|
||
|
+ }
|
||
|
+ for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
|
||
|
+ const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
|
||
|
+ if (niter < niter_best) {
|
||
|
+ niter_best = niter;
|
||
|
+ block_size_best = block_size;
|
||
|
+ }
|
||
|
+ }
|
||
|
+
|
||
|
+ const int smem = warp_size*sizeof(float);
|
||
|
+ const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
||
|
+ const dim3 block_dims(block_size_best, 1, 1);
|
||
|
+
|
||
|
+ switch (block_size_best) {
|
||
|
+ case 32: {
|
||
|
+ mul_mat_vec_mxfp4<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
||
|
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||
|
+ } break;
|
||
|
+ case 64: {
|
||
|
+ mul_mat_vec_mxfp4<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
||
|
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||
|
+ } break;
|
||
|
+ case 96: {
|
||
|
+ mul_mat_vec_mxfp4<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
||
|
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||
|
+ } break;
|
||
|
+ case 128: {
|
||
|
+ mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
||
|
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||
|
+ } break;
|
||
|
+ case 160: {
|
||
|
+ mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
||
|
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||
|
+ } break;
|
||
|
+ case 192: {
|
||
|
+ mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
||
|
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||
|
+ } break;
|
||
|
+ case 224: {
|
||
|
+ mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
||
|
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||
|
+ } break;
|
||
|
+ case 256: {
|
||
|
+ mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
||
|
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||
|
+ } break;
|
||
|
+ default: {
|
||
|
+ GGML_ABORT("fatal error");
|
||
|
+ } break;
|
||
|
+ }
|
||
|
+}
|
||
|
+
|
||
|
+static void mul_mat_vec_cuda_mxfp4(
|
||
|
+ const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst,
|
||
|
+ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||
|
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||
|
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||
|
+ enum ggml_prec prec, cudaStream_t stream) {
|
||
|
+ launch_mul_mat_vec_cuda_mxfp4<float>
|
||
|
+ (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||
|
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||
|
+}
|
||
|
+
|
||
|
+void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||
|
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||
|
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
||
|
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||
|
+
|
||
|
+ GGML_TENSOR_BINARY_OP_LOCALS;
|
||
|
+
|
||
|
+ const size_t ts_src0 = ggml_type_size(src0->type);
|
||
|
+ const size_t ts_src1 = ggml_type_size(src1->type);
|
||
|
+ const size_t ts_dst = ggml_type_size(dst->type);
|
||
|
+
|
||
|
+ GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
|
||
|
+ GGML_ASSERT(ne13 == ne3);
|
||
|
+
|
||
|
+ // GGML_ASSERT( nb00 == ts_src0); // TODO adjust for block sizing logic
|
||
|
+ GGML_ASSERT( nb10 == ts_src1);
|
||
|
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
|
||
|
+ GGML_ASSERT( nb0 == ts_dst);
|
||
|
+
|
||
|
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||
|
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||
|
+
|
||
|
+ const float * src1_d = (const float *) src1->data;
|
||
|
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||
|
+ float * dst_d = (float *) dst->data;
|
||
|
+
|
||
|
+ const int64_t stride_row = src0->nb[1] / ts_src0;
|
||
|
+ const int64_t s11 = src1->nb[1] / ts_src1;
|
||
|
+ const int64_t s1 = dst->nb[1] / ts_dst;
|
||
|
+ const int64_t stride_channel_x = src0->nb[2] / ts_src0;
|
||
|
+ const int64_t s12 = src1->nb[2] / ts_src1;
|
||
|
+ const int64_t s2 = dst->nb[2] / ts_dst;
|
||
|
+ const int64_t stride_sample_x = src0->nb[3] / ts_src0;
|
||
|
+ const int64_t stride_sample_y = src1->nb[3] / ts_src1;
|
||
|
+ const int64_t stride_sample_dst = dst->nb[3] / ts_dst;
|
||
|
+ const int64_t nsamples_dst = ne3;
|
||
|
+ const int64_t nsamples_x = ne03;
|
||
|
+ const int64_t nchannels_x = ne02;
|
||
|
+ const int64_t nrows = ne01;
|
||
|
+ const int64_t ncols = ne00;
|
||
|
+
|
||
|
+ // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
|
||
|
+ const int64_t ncols_dst = ids ? ne2 : ne1;
|
||
|
+ const int64_t nchannels_y = ids ? ne11 : ne12;
|
||
|
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
|
||
|
+ const int64_t stride_channel_dst = ids ? s1 : s2;
|
||
|
+ const int64_t stride_channel_y = ids ? s11 : s12;
|
||
|
+
|
||
|
+ GGML_ASSERT(ncols_dst == 1);
|
||
|
+
|
||
|
+ const block_mxfp4 * src0_d = (const block_mxfp4 *) src0->data;
|
||
|
+ mul_mat_vec_cuda_mxfp4(src0_d, src1_d, ids_d, dst_d, ncols, nrows, stride_row,
|
||
|
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||
|
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, ctx.stream());
|
||
|
+}
|
||
|
+
|
||
|
+void ggml_cuda_op_mul_mat_vec_mxfp4(
|
||
|
+ ggml_backend_cuda_context & ctx,
|
||
|
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||
|
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||
|
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
|
||
|
+
|
||
|
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||
|
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||
|
+
|
||
|
+ const int64_t ne00 = src0->ne[0];
|
||
|
+ const int64_t row_diff = row_high - row_low;
|
||
|
+
|
||
|
+ GGML_ASSERT(src1_ncols == 1);
|
||
|
+
|
||
|
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||
|
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||
|
+
|
||
|
+ // ggml_cuda_op provides single, contiguous matrices
|
||
|
+ const int64_t stride_row = ne00 / MXFP4;
|
||
|
+ const int64_t nchannels_x = 1;
|
||
|
+ const int64_t nchannels_y = 1;
|
||
|
+ const int64_t nchannels_dst = 1;
|
||
|
+ const int64_t stride_channel_x = 0;
|
||
|
+ const int64_t stride_channel_y = 0;
|
||
|
+ const int64_t stride_channel_dst = 0;
|
||
|
+ const int64_t nsamples_x = 1;
|
||
|
+ const int64_t nsamples_dst = 1;
|
||
|
+ const int64_t stride_sample_x = 0;
|
||
|
+ const int64_t stride_sample_y = 0;
|
||
|
+ const int64_t stride_sample_dst = 0;
|
||
|
+
|
||
|
+ const block_mxfp4 * src0_d = (const block_mxfp4 *) src0_dd_i;
|
||
|
+ mul_mat_vec_cuda_mxfp4(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
|
||
|
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||
|
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||
|
+
|
||
|
+ GGML_UNUSED(ctx);
|
||
|
+ GGML_UNUSED(src1);
|
||
|
+ GGML_UNUSED(dst);
|
||
|
+ GGML_UNUSED(src1_ddq_i);
|
||
|
+ GGML_UNUSED(src1_ncols);
|
||
|
+ GGML_UNUSED(src1_padded_row_size);
|
||
|
+}
|
||
|
diff --git a/ggml/src/ggml-cuda/mmvmxfp4.cuh b/ggml/src/ggml-cuda/mmvmxfp4.cuh
|
||
|
new file mode 100644
|
||
|
index 00000000..a08fc780
|
||
|
--- /dev/null
|
||
|
+++ b/ggml/src/ggml-cuda/mmvmxfp4.cuh
|
||
|
@@ -0,0 +1,9 @@
|
||
|
+#include "common.cuh"
|
||
|
+
|
||
|
+void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||
|
+
|
||
|
+void ggml_cuda_op_mul_mat_vec_mxfp4(
|
||
|
+ ggml_backend_cuda_context & ctx,
|
||
|
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||
|
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||
|
+ const int64_t src1_padded_row_size, cudaStream_t stream);
|
||
|
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||
|
index 17eab976..938386ba 100644
|
||
|
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
|
||
|
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||
|
@@ -65,6 +65,9 @@
|
||
|
#define N_R0_IQ4_XS 2
|
||
|
#define N_SG_IQ4_XS 2
|
||
|
|
||
|
+#define N_R0_MXFP4 4
|
||
|
+#define N_SG_MXFP4 2
|
||
|
+
|
||
|
// kernel argument structs
|
||
|
//
|
||
|
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||
|
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||
|
index ab46f6e3..d8e05a21 100644
|
||
|
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||
|
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||
|
@@ -40,6 +40,7 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
|
||
|
static struct ggml_backend_reg g_ggml_backend_metal_reg;
|
||
|
static struct ggml_backend_device g_ggml_backend_metal_device;
|
||
|
|
||
|
+
|
||
|
// information about a Metal device
|
||
|
// note: assumes single GPU device - the default one
|
||
|
// TODO: support multiple GPU devices
|
||
|
@@ -209,6 +210,7 @@ enum ggml_metal_kernel_type {
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
||
|
+ GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
||
|
@@ -288,6 +290,7 @@ enum ggml_metal_kernel_type {
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
||
|
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
|
||
|
@@ -310,6 +313,7 @@ enum ggml_metal_kernel_type {
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||
|
+ GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
|
||
|
@@ -334,6 +338,7 @@ enum ggml_metal_kernel_type {
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
|
||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
|
||
|
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
|
||
|
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
||
|
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
||
|
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
|
||
|
@@ -934,7 +939,7 @@ static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfl
|
||
|
|
||
|
MTLCompileOptions * options = [MTLCompileOptions new];
|
||
|
options.preprocessorMacros = prep;
|
||
|
-
|
||
|
+
|
||
|
//[options setFastMathEnabled:false];
|
||
|
|
||
|
metal_library = [device newLibraryWithSource:src options:options error:&error];
|
||
|
@@ -1157,6 +1162,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
||
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
||
|
@@ -1236,6 +1242,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
|
||
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
|
||
|
@@ -1258,6 +1265,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
||
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
|
||
|
@@ -1282,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
|
||
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
|
||
|
@@ -3007,6 +3016,7 @@ static bool ggml_metal_encode_node(
|
||
|
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
||
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
||
|
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
||
|
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
|
||
|
default: GGML_ABORT("MUL MAT-MAT not implemented");
|
||
|
}
|
||
|
|
||
|
@@ -3212,6 +3222,12 @@ static bool ggml_metal_encode_node(
|
||
|
smem = 32*sizeof(float);
|
||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
||
|
} break;
|
||
|
+ case GGML_TYPE_MXFP4:
|
||
|
+ {
|
||
|
+ nsg = N_SG_MXFP4;
|
||
|
+ nr0 = N_R0_MXFP4;
|
||
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
|
||
|
+ } break;
|
||
|
default:
|
||
|
{
|
||
|
GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
||
|
@@ -3396,6 +3412,7 @@ static bool ggml_metal_encode_node(
|
||
|
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
|
||
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
|
||
|
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
|
||
|
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
|
||
|
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
||
|
}
|
||
|
|
||
|
@@ -3607,6 +3624,12 @@ static bool ggml_metal_encode_node(
|
||
|
smem = 32*sizeof(float);
|
||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
||
|
} break;
|
||
|
+ case GGML_TYPE_MXFP4:
|
||
|
+ {
|
||
|
+ nsg = N_SG_MXFP4;
|
||
|
+ nr0 = N_R0_MXFP4;
|
||
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
|
||
|
+ } break;
|
||
|
default:
|
||
|
{
|
||
|
GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
||
|
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||
|
index 08e8d807..69fa17de 100644
|
||
|
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||
|
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||
|
@@ -1902,16 +1902,16 @@ void mul_vec_q_n_f32_impl(
|
||
|
device const char * src1,
|
||
|
device char * dst,
|
||
|
threadgroup char * shmem,
|
||
|
- uint3 tgpig,
|
||
|
- ushort tiisg,
|
||
|
- ushort sgitg) {
|
||
|
- const int nb = args.ne00/QK4_0;
|
||
|
+ uint3 tgpig, // Threadgroup Position in Grid
|
||
|
+ ushort tiisg, // Thread Index in SIMD Group
|
||
|
+ ushort sgitg) { // SIMD Group Index in ThreadGroup
|
||
|
+ const int nb = args.ne00/QK4_0; // src0->ne[0] / 32
|
||
|
|
||
|
const int r0 = tgpig.x;
|
||
|
const int r1 = tgpig.y;
|
||
|
const int im = tgpig.z;
|
||
|
|
||
|
- const int first_row = (r0 * nsg + sgitg) * nr0;
|
||
|
+ const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4
|
||
|
|
||
|
const uint i12 = im%args.ne12;
|
||
|
const uint i13 = im/args.ne12;
|
||
|
@@ -6744,6 +6744,49 @@ kernel void kernel_mul_mm_id(
|
||
|
}
|
||
|
}
|
||
|
|
||
|
+template <typename type4x4>
|
||
|
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) { // decode 16 FP4 (e2m1) values of block xb into reg
|
||
|
+ float4x4 reg_f;
|
||
|
+ const ushort dst_bias = 15; // float16 exponent bias
|
||
|
+ const ushort dst_0p5 = 0x3800; // 0.5 encoded as float16
|
||
|
+ const ushort dst_m_bits = 10; // float16 mantissa width
|
||
|
+ const float scale = as_type<float>(xb->d == 0xFF ? 0x7FC00000u : ((uint32_t)xb->d) << 23); // e8m0: 0xFF is NaN (a bare shift gives +inf); float because half only spans ~2^-24..2^15
|
||
|
+ // il:0 first 16, il:1 last 16; each qs byte holds two values, low nibble first
|
||
|
+ for (int i = 0; i < 8; i++) {
|
||
|
+ ushort em0 = xb->qs[il*8 + i] & 0x07;
|
||
|
+ ushort em1 = xb->qs[il*8 + i] & 0x70;
|
||
|
+ // float16 values
|
||
|
+ ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12);
|
||
|
+ ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8);
|
||
|
+
|
||
|
+ // Three cases:
|
||
|
+ // x is normal and non-zero: Correct bias
|
||
|
+ if ((em0 & 0x06) != 0) {
|
||
|
+ x0 = x0 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ if ((em1 & 0x60) != 0) {
|
||
|
+ x1 = x1 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
||
|
+ if (em0 == 0x01) {
|
||
|
+ x0 = dst_0p5 | (x0 & 0x8000);
|
||
|
+ }
|
||
|
+ if (em1 == 0x10) {
|
||
|
+ x1 = dst_0p5 | (x1 & 0x8000);
|
||
|
+ }
|
||
|
+ // x is zero, do nothing
|
||
|
+
|
||
|
+ if (isnan(scale)) {
|
||
|
+ reg_f[i/2][2*(i%2) + 0] = scale;
|
||
|
+ reg_f[i/2][2*(i%2) + 1] = scale;
|
||
|
+ } else {
|
||
|
+ reg_f[i/2][2*(i%2) + 0] = scale * as_type<half>(x0);
|
||
|
+ reg_f[i/2][2*(i%2) + 1] = scale * as_type<half>(x1);
|
||
|
+ }
|
||
|
+ }
|
||
|
+ reg = (type4x4) reg_f;
|
||
|
+}
|
||
|
+
|
||
|
#define QK_NL 16
|
||
|
|
||
|
//
|
||
|
@@ -6811,6 +6854,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m
|
||
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||
|
|
||
|
+template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
|
||
|
+
|
||
|
//
|
||
|
// indirect matrix-matrix multiplication
|
||
|
//
|
||
|
@@ -6842,6 +6887,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m
|
||
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||
|
|
||
|
+template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
|
||
|
+
|
||
|
|
||
|
//
|
||
|
// matrix-vector multiplication
|
||
|
@@ -6958,6 +7005,120 @@ kernel void kernel_mul_mv_id(
|
||
|
sgitg);
|
||
|
}
|
||
|
|
||
|
+// MXFP4 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y
|
||
|
+void mul_mv_mxfp4_f32_impl( // dst = src0 (MXFP4 rows) x src1 (f32 vector); each SIMD group produces nr0 rows
|
||
|
+ ggml_metal_kargs_mul_mv args,
|
||
|
+ device const char * src0,
|
||
|
+ device const char * src1,
|
||
|
+ device char * dst,
|
||
|
+ threadgroup char * shmem,
|
||
|
+ uint3 tgpig,
|
||
|
+ ushort tiisg,
|
||
|
+ ushort sgitg) {
|
||
|
+ const ushort dst_bias = 15; // float16 exponent bias
|
||
|
+ const ushort dst_0p5 = 0x3800; // 0.5 encoded as float16
|
||
|
+ const ushort dst_m_bits = 10; // float16 mantissa width
|
||
|
+ const int nr0 = N_R0_MXFP4;
|
||
|
+ const int nsg = N_SG_MXFP4;
|
||
|
+ const int nw = N_SIMDWIDTH;
|
||
|
+ const int nb = args.ne00/MXFP4;
|
||
|
+
|
||
|
+ const int r0 = tgpig.x;
|
||
|
+ const int r1 = tgpig.y;
|
||
|
+ const int im = tgpig.z;
|
||
|
+
|
||
|
+ const int first_row = (r0 * nsg + sgitg) * nr0;
|
||
|
+
|
||
|
+ const uint i12 = im%args.ne12;
|
||
|
+ const uint i13 = im/args.ne12;
|
||
|
+
|
||
|
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||
|
+
|
||
|
+ device const float * y = (device const float *) (src1 + offset1);
|
||
|
+
|
||
|
+ // pointers to src0 rows
|
||
|
+ device const block_mxfp4 * ax[nr0];
|
||
|
+ for (int row = 0; row < nr0; ++row) {
|
||
|
+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||
|
+
|
||
|
+ ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0);
|
||
|
+ }
|
||
|
+
|
||
|
+ // src1 values are read directly through yb below (no local cache)
|
||
|
+ float sumf[nr0] = {0.f};
|
||
|
+
|
||
|
+ const short ix = (tiisg/2);
|
||
|
+ const short il = (tiisg%2)*16;
|
||
|
+
|
||
|
+ device const float * yb = y + ix*MXFP4 + il;
|
||
|
+
|
||
|
+ // each thread in a SIMD group deals with half a block.
|
||
|
+ for (int ib = ix; ib < nb; ib += nw/2) {
|
||
|
+
|
||
|
+#pragma unroll
|
||
|
+ for (short row = 0; row < nr0; row++) {
|
||
|
+ // each thread handles 16 of the block's values (il selects the first or second half)
|
||
|
+ device const block_mxfp4 * qb_curr = ax[row] + ib;
|
||
|
+ float d = as_type<float>(qb_curr->d == 0xFF ? 0x7FC00000u : ((uint32_t)qb_curr->d) << 23); // e8m0 scale; 0xFF is NaN (a bare shift gives +inf)
|
||
|
+ // qs follows the 1-byte scale; il/2 (0 or 8) skips to this thread's half
|
||
|
+ device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2);
|
||
|
+ for (int i = 0; i < 8; ++i) {
|
||
|
+ ushort em0 = qs[i] & 0x07;
|
||
|
+ ushort em1 = qs[i] & 0x70;
|
||
|
+ ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12);
|
||
|
+ ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8);
|
||
|
+ // Three cases:
|
||
|
+ // x is normal and non-zero: Correct bias
|
||
|
+ if ((em0 & 0x06) != 0) {
|
||
|
+ x0 = x0 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ if ((em1 & 0x60) != 0) {
|
||
|
+ x1 = x1 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
||
|
+ if (em0 == 0x01) {
|
||
|
+ x0 = dst_0p5 | (x0 & 0x8000);
|
||
|
+ }
|
||
|
+ if (em1 == 0x10) {
|
||
|
+ x1 = dst_0p5 | (x1 & 0x8000);
|
||
|
+ }
|
||
|
+ // x is zero, do nothing
|
||
|
+ if (!isnan(d)) {
|
||
|
+ sumf[row] += yb[i*2] * as_type<half>(x0) * d
|
||
|
+ + yb[i*2+1] * as_type<half>(x1) * d;
|
||
|
+ } else {
|
||
|
+ sumf[row] = d;
|
||
|
+ }
|
||
|
+ }
|
||
|
+ }
|
||
|
+
|
||
|
+ yb += MXFP4 * (nw/2); // skip the nw/2 blocks that ib advanced over
|
||
|
+ }
|
||
|
+
|
||
|
+ device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||
|
+
|
||
|
+ for (int row = 0; row < nr0; ++row) {
|
||
|
+ const float tot = simd_sum(sumf[row]);
|
||
|
+
|
||
|
+ if (tiisg == 0 && first_row + row < args.ne01) {
|
||
|
+ dst_f32[first_row + row] = tot;
|
||
|
+ }
|
||
|
+ }
|
||
|
+}
|
||
|
+
|
||
|
+[[host_name("kernel_mul_mv_mxfp4_f32")]] // thin entry point: all work is done by mul_mv_mxfp4_f32_impl
|
||
|
+kernel void kernel_mul_mv_mxfp4_f32(
|
||
|
+ constant ggml_metal_kargs_mul_mv & args,
|
||
|
+ device const char * src0,
|
||
|
+ device const char * src1,
|
||
|
+ device char * dst,
|
||
|
+ threadgroup char * tg_mem [[threadgroup(0)]],
|
||
|
+ uint3 tg_pos [[threadgroup_position_in_grid]],
|
||
|
+ ushort lane [[thread_index_in_simdgroup]],
|
||
|
+ ushort simd_group [[simdgroup_index_in_threadgroup]]) {
|
||
|
+ mul_mv_mxfp4_f32_impl(args, src0, src1, dst, tg_mem, tg_pos, lane, simd_group);
|
||
|
+}
|
||
|
+
|
||
|
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
|
||
|
|
||
|
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
||
|
@@ -6987,6 +7148,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t
|
||
|
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
|
||
|
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
|
||
|
|
||
|
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_mv_mxfp4_f32_impl>>;
|
||
|
+
|
||
|
kernel void kernel_pool_2d_max_f32(
|
||
|
device const float * src0,
|
||
|
device float * dst,
|
||
|
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
|
||
|
index 84ec6dfe..17c308aa 100644
|
||
|
--- a/ggml/src/ggml-quants.c
|
||
|
+++ b/ggml/src/ggml-quants.c
|
||
|
@@ -4925,6 +4925,144 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE
|
||
|
quantize_iq2_s(x, y, 1, k, NULL);
|
||
|
}
|
||
|
|
||
|
+// =============================== mxfp4 (de)-quantization
|
||
|
+
|
||
|
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { // per block: e8m0 power-of-two scale + 4-bit e2m1 codes, two per byte
|
||
|
+ static const int qk = MXFP4;
|
||
|
+ static const uint32_t E8_BIAS = 127;
|
||
|
+ static const uint32_t E2_BIAS = 1;
|
||
|
+
|
||
|
+ assert(k % qk == 0);
|
||
|
+
|
||
|
+ const int nb = k / qk;
|
||
|
+
|
||
|
+ for (int i = 0; i < nb; i++) {
|
||
|
+ float amax = 0.0f; // absolute max
|
||
|
+
|
||
|
+ for (int j = 0; j < qk; j++) {
|
||
|
+ const float v = x[i*qk + j];
|
||
|
+ if (amax < fabsf(v)) {
|
||
|
+ amax = fabsf(v);
|
||
|
+ }
|
||
|
+ }
|
||
|
+
|
||
|
+ const float dequant_scale = amax / 6.0f;
|
||
|
+ uint32_t dequant_scale_exponent = 0;
|
||
|
+ memcpy(&dequant_scale_exponent, &dequant_scale, sizeof(dequant_scale_exponent));
|
||
|
+
|
||
|
+ // Rounding up to a power of two keeps amax*quant_scale <= 6, the largest e2m1 magnitude
|
||
|
+ dequant_scale_exponent = (dequant_scale_exponent + 0x007FFFFF) & 0x7F800000;
|
||
|
+ // Rounding down
|
||
|
+ // dequant_scale_exponent = dequant_scale_exponent & 0x7F800000;
|
||
|
+
|
||
|
+ float dequant_scale_rounded = 0.0f;
|
||
|
+ memcpy(&dequant_scale_rounded, &dequant_scale_exponent, sizeof(dequant_scale_rounded));
|
||
|
+ float quant_scale = 0.0f;
|
||
|
+ if (dequant_scale_rounded != 0.0f) {
|
||
|
+ quant_scale = 1.0f / dequant_scale_rounded;
|
||
|
+ }
|
||
|
+
|
||
|
+ y[i].d = (uint8_t)(dequant_scale_exponent >> 23);
|
||
|
+
|
||
|
+ for (int j = 0; j < qk/2; ++j) {
|
||
|
+ const float x0 = x[i*qk + j*2]*quant_scale;
|
||
|
+ const float x1 = x[i*qk + j*2+1]*quant_scale;
|
||
|
+
|
||
|
+ uint32_t xi0 = 0;
|
||
|
+ uint32_t xi1 = 0;
|
||
|
+ memcpy(&xi0, &x0, sizeof(xi0));
|
||
|
+ memcpy(&xi1, &x1, sizeof(xi1));
|
||
|
+
|
||
|
+ uint32_t s0 = xi0 & 0x80000000;
|
||
|
+ uint32_t s1 = xi1 & 0x80000000;
|
||
|
+ uint32_t e0 = (xi0 >> 23) & 0xFF;
|
||
|
+ uint32_t e1 = (xi1 >> 23) & 0xFF;
|
||
|
+ uint32_t m0 = (xi0 & 0x7FFFFF);
|
||
|
+ uint32_t m1 = (xi1 & 0x7FFFFF);
|
||
|
+
|
||
|
+ // 0.25 <= x < 0.75 maps to 0.5, a denormal number
|
||
|
+ // Move implicit bit 1 at the beginning to mantissa for denormals
|
||
|
+ // adjusted_exponents: right-shift that denormalizes values below 1.0 (only used when e < 127)
|
||
|
+ uint32_t ae0 = E8_BIAS - (e0 + 1);
|
||
|
+ uint32_t ae1 = E8_BIAS - (e1 + 1);
|
||
|
+ if (e0 < E8_BIAS) {
|
||
|
+ m0 = ae0 < 32 ? (0x400000 | (m0 >> 1)) >> ae0 : 0; // shift >= 32 is UB (e.g. x == 0 gives 126); such values round to 0
|
||
|
+ }
|
||
|
+ if (e1 < E8_BIAS) {
|
||
|
+ m1 = ae1 < 32 ? (0x400000 | (m1 >> 1)) >> ae1 : 0; // shift >= 32 is UB (e.g. x == 0 gives 126); such values round to 0
|
||
|
+ }
|
||
|
+
|
||
|
+ // For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
|
||
|
+ e0 = MAX(e0, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS);
|
||
|
+ e1 = MAX(e1, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS);
|
||
|
+
|
||
|
+ // Combine sign, exponent, and mantissa, while saturating
|
||
|
+ // rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
|
||
|
+ uint32_t tmp0 = MIN((((e0 << 2) | (m0 >> 21)) + 1) >> 1, 0x7);
|
||
|
+ uint32_t tmp1 = MIN((((e1 << 2) | (m1 >> 21)) + 1) >> 1, 0x7);
|
||
|
+ uint8_t v0 = (uint8_t)((s0 >> 28) | tmp0);
|
||
|
+ uint8_t v1 = (uint8_t)((s1 >> 28) | tmp1);
|
||
|
+ y[i].qs[j] = v0;
|
||
|
+ y[i].qs[j] |= v1 << 4;
|
||
|
+ }
|
||
|
+ }
|
||
|
+}
|
||
|
+
|
||
|
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { // inverse of quantize_row_mxfp4_ref; k must be a multiple of MXFP4
|
||
|
+ assert(k % MXFP4 == 0);
|
||
|
+
|
||
|
+ const int nb = k / MXFP4;
|
||
|
+ const uint16_t dst_bias = 15; // float16 exponent bias
|
||
|
+ const uint16_t dst_0p5 = 0x3800; // 0.5 encoded as float16
|
||
|
+ const uint16_t dst_m_bits = 10; // float16 mantissa width
|
||
|
+
|
||
|
+ for (int i = 0; i < nb; i++) {
|
||
|
+ union {
|
||
|
+ uint32_t as_bits;
|
||
|
+ float as_value;
|
||
|
+ } scale;
|
||
|
+ scale.as_bits = x[i].d == 0xFF ? 0x7FC00000u : (((uint32_t)x[i].d) << 23); // e8m0 exponent field; 0xFF is NaN (a bare shift would give +inf)
|
||
|
+ for (int j = 0; j < MXFP4/2; ++j) {
|
||
|
+ uint16_t em0 = x[i].qs[j] & 0x07;
|
||
|
+ uint16_t em1 = x[i].qs[j] & 0x70;
|
||
|
+ // float16 values
|
||
|
+ uint16_t x0 = (em0 << (dst_m_bits - 1)) | ((x[i].qs[j] & 0x08) << 12);
|
||
|
+ uint16_t x1 = (em1 << (dst_m_bits - 5)) | ((x[i].qs[j] & 0x80) << 8);
|
||
|
+
|
||
|
+ // Three cases:
|
||
|
+ // x is normal and non-zero: Correct bias
|
||
|
+ if ((em0 & 0x06) != 0) {
|
||
|
+ x0 = x0 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ if ((em1 & 0x60) != 0) {
|
||
|
+ x1 = x1 + ((dst_bias - 1) << dst_m_bits);
|
||
|
+ }
|
||
|
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
||
|
+ if (em0 == 0x01) {
|
||
|
+ x0 = dst_0p5 | (x0 & 0x8000);
|
||
|
+ }
|
||
|
+ if (em1 == 0x10) {
|
||
|
+ x1 = dst_0p5 | (x1 & 0x8000);
|
||
|
+ }
|
||
|
+ // x is zero, do nothing
|
||
|
+
|
||
|
+ if (isnan(scale.as_value)) {
|
||
|
+ y[i*MXFP4 + j*2] = scale.as_value;
|
||
|
+ y[i*MXFP4 + j*2+1] = scale.as_value;
|
||
|
+ } else {
|
||
|
+ y[i*MXFP4 + j*2] = GGML_FP16_TO_FP32(x0)*scale.as_value;
|
||
|
+ y[i*MXFP4 + j*2+1] = GGML_FP16_TO_FP32(x1)*scale.as_value;
|
||
|
+ }
|
||
|
+ }
|
||
|
+ }
|
||
|
+}
|
||
|
+
|
||
|
+
|
||
|
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||
|
+ GGML_UNUSED(quant_weights); quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row); // no imatrix support: plain reference quantization of all rows
|
||
|
+ return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
|
||
|
+}
|
||
|
+
|
||
|
// =============================== data validation
|
||
|
|
||
|
static bool validate_float(float f, size_t i) {
|
||
|
@@ -5214,7 +5352,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||
|
{
|
||
|
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
|
||
|
} break;
|
||
|
-
|
||
|
+ case GGML_TYPE_MXFP4:
|
||
|
+ // TODO - anything to validate?
|
||
|
+ break;
|
||
|
case GGML_TYPE_I8:
|
||
|
case GGML_TYPE_I16:
|
||
|
case GGML_TYPE_I32:
|
||
|
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
|
||
|
index d09173e1..2fc40f75 100644
|
||
|
--- a/ggml/src/ggml-quants.h
|
||
|
+++ b/ggml/src/ggml-quants.h
|
||
|
@@ -37,6 +37,8 @@ GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_
|
||
|
GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
|
||
|
GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
|
||
|
|
||
|
+GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
|
||
|
+
|
||
|
// Dequantization
|
||
|
GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||
|
GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||
|
@@ -65,6 +67,8 @@ GGML_API void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, floa
|
||
|
GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||
|
GGML_API void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||
|
|
||
|
+GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||
|
+
|
||
|
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
||
|
GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||
|
GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||
|
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
|
||
|
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||
|
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||
|
|
||
|
+GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||
|
+
|
||
|
GGML_API void iq2xs_init_impl(enum ggml_type type);
|
||
|
GGML_API void iq2xs_free_impl(enum ggml_type type);
|
||
|
GGML_API void iq3xs_init_impl(int grid_size);
|
||
|
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
|
||
|
index 8a654624..0f3c9834 100644
|
||
|
--- a/ggml/src/ggml.c
|
||
|
+++ b/ggml/src/ggml.c
|
||
|
@@ -589,11 +589,13 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
|
||
|
.to_float = (ggml_to_float_t) dequantize_row_q4_1,
|
||
|
.from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
|
||
|
},
|
||
|
- [4] = { // GGML_TYPE_Q4_2
|
||
|
- .type_name = "DEPRECATED",
|
||
|
- .blck_size = 0,
|
||
|
- .type_size = 0,
|
||
|
- .is_quantized = false,
|
||
|
+ [GGML_TYPE_MXFP4] = { // formerly deprecated GGML_TYPE_Q4_2
|
||
|
+ .type_name = "mxfp4",
|
||
|
+ .blck_size = MXFP4,
|
||
|
+ .type_size = sizeof(block_mxfp4),
|
||
|
+ .is_quantized = true,
|
||
|
+ .to_float = (ggml_to_float_t) dequantize_row_mxfp4,
|
||
|
+ .from_float_ref = (ggml_from_float_t) quantize_row_mxfp4_ref,
|
||
|
},
|
||
|
[5] = { // GGML_TYPE_Q4_3
|
||
|
.type_name = "DEPRECATED",
|
||
|
@@ -6446,6 +6448,7 @@ size_t ggml_quantize_chunk(
|
||
|
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||
|
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||
|
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||
|
+ case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||
|
case GGML_TYPE_F16:
|
||
|
{
|
||
|
size_t elemsize = sizeof(ggml_fp16_t);
|