ollama/llama/patches/0023-MXFP4.patch

From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Mon, 21 Jul 2025 12:06:13 -0700
Subject: [PATCH] MXFP4

Partial implementation of MXFP4 tensor type
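
For reference, each MXFP4 block packs 32 values into 17 bytes: a shared
E8M0 scale (a bare 8-bit exponent) followed by 32 E2M1 4-bit elements,
stored two per byte with the low nibble first. A minimal scalar sketch of
the decode is shown below; it is illustrative only (not part of the patch,
identifiers invented here) and uses the same lookup table as MXFP4_VALS in
ggml-cpu/vec.cpp:

    #include <stdint.h>
    #include <string.h>

    #define MXFP4 32

    typedef struct {
        uint8_t d;              // E8M0 scale: just a biased exponent
        uint8_t qs[MXFP4 / 2];  // 32 x 4-bit E2M1 values, low nibble first
    } block_mxfp4;

    // E2M1 code -> value, matching MXFP4_VALS in ggml-cpu/vec.cpp
    static const float kMxfp4Vals[16] = {
        0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
        0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
    };

    static void decode_block_mxfp4(const block_mxfp4 *x, float out[MXFP4]) {
        // E8M0 -> float: drop the 8 exponent bits into an IEEE-754 float,
        // i.e. scale = 2^(d - 127), the same trick as (uint32_t)x->d << 23
        // used throughout the patch.
        uint32_t bits = (uint32_t)x->d << 23;
        float scale;
        memcpy(&scale, &bits, sizeof(scale));

        for (int j = 0; j < MXFP4 / 2; ++j) {
            out[j*2 + 0] = scale * kMxfp4Vals[ x->qs[j]       & 0x0F]; // low nibble
            out[j*2 + 1] = scale * kMxfp4Vals[(x->qs[j] >> 4) & 0x0F]; // high nibble
        }
    }

The CUDA and Metal kernels avoid the table and expand each E2M1 nibble
directly to FP16 bits; e.g. nibble 0x7 (+6.0) shifts its exponent/mantissa
bits to 0x0E00 and, after the (dst_bias - 1) << 10 correction for normal
values, becomes the FP16 pattern 0x4600 = 6.0. Quantization goes the other
way: the per-block scale is amax/6 rounded up to a power of two and stored
as its 8-bit exponent, and each input is mapped to the nearest E2M1 code
(see quantize_row_mxfp4_ref).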
---
 ggml/include/ggml.h                   |   2 +-
 ggml/src/ggml-common.h                |   7 +
 ggml/src/ggml-cpu/ggml-cpu-quants.h   |   2 +
 ggml/src/ggml-cpu/ggml-cpu.c          |   5 +
 ggml/src/ggml-cpu/ops.cpp             |   1 +
 ggml/src/ggml-cpu/vec.cpp             |  90 ++++++++
 ggml/src/ggml-cpu/vec.h               |   2 +
 ggml/src/ggml-cuda/convert.cu         |  80 +++++++
 ggml/src/ggml-cuda/ggml-cuda.cu       |  16 +-
 ggml/src/ggml-cuda/mmvmxfp4.cu        | 307 ++++++++++++++++++++++++++
 ggml/src/ggml-cuda/mmvmxfp4.cuh       |   9 +
 ggml/src/ggml-metal/ggml-metal-impl.h |   3 +
 ggml/src/ggml-metal/ggml-metal.m      |  25 ++-
 ggml/src/ggml-metal/ggml-metal.metal  | 173 ++++++++++++++-
 ggml/src/ggml-quants.c                | 142 +++++++++++-
 ggml/src/ggml-quants.h                |   6 +
 ggml/src/ggml.c                       |  13 +-
17 files changed, 868 insertions(+), 15 deletions(-)
create mode 100644 ggml/src/ggml-cuda/mmvmxfp4.cu
create mode 100644 ggml/src/ggml-cuda/mmvmxfp4.cuh
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index e91dedf1..873baa24 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -353,7 +353,7 @@ extern "C" {
GGML_TYPE_F16 = 1,
GGML_TYPE_Q4_0 = 2,
GGML_TYPE_Q4_1 = 3,
- // GGML_TYPE_Q4_2 = 4, support has been removed
+ GGML_TYPE_MXFP4 = 4, // Formerly removed type GGML_TYPE_Q4_2
// GGML_TYPE_Q4_3 = 5, support has been removed
GGML_TYPE_Q5_0 = 6,
GGML_TYPE_Q5_1 = 7,
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 086c822d..e0d71451 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -417,6 +417,13 @@ typedef struct {
} block_iq4_xs;
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+#define MXFP4 32
+typedef struct {
+ uint8_t d; // scale E8M0 float
+ uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");
+
#endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL
diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.h b/ggml/src/ggml-cpu/ggml-cpu-quants.h
index e33d9d47..6a25d062 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-quants.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-quants.h
@@ -58,6 +58,8 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
+
#ifdef __cplusplus
}
#endif
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2462d2b8..bff9c426 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -362,6 +362,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
+ [GGML_TYPE_MXFP4] = {
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_mxfp4,
+ .vec_dot_type = GGML_TYPE_F32,
+ .nrows = 1,
+ },
};
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 654e2f28..be0aa683 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4965,6 +4965,7 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_COUNT:
{
GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index 02d40618..ec3ec9b1 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -250,3 +250,93 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl
}
return sum = (ggml_float)logf(sum);
}
+
+#define MXFP4 32
+typedef struct {
+ uint8_t d; // scale E8M0 float
+ uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding");
+#define MXFP4_VALS {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0}
+
+void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
+ assert(nrc == 1);
+ GGML_UNUSED(nrc);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(bs);
+ ggml_float mxfp4_table[] = MXFP4_VALS;
+
+#if defined(GGML_SIMD)
+ float sumf = 0.0f;
+ const int np = (n & ~(GGML_F32_STEP - 1));
+ const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx;
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC scalev;
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ for (int i = 0; i < np; i += GGML_F32_STEP) { // ARM: +16 AVX512: +64
+ for (int j = 0; j < GGML_F32_ARR; j++) { // ARM: 0 .. 4 AVX512: 0 .. 4
+ // convert GGML_F32_ARR X elements
+ const int ib = (i + j*GGML_F32_EPR) / MXFP4;
+ const block_mxfp4 * GGML_RESTRICT x = &xx[ib];
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } scale;
+ scale.as_bits = (((uint32_t)x->d) << 23);
+ scalev = GGML_F32_VEC_SET1(scale.as_value);
+ float xf[GGML_F32_EPR]= {0.f};
+ assert(((i+j*GGML_F32_EPR) % MXFP4)+GGML_F32_ARR < MXFP4 && "block overrun");
+ for (int qi = 0; qi < GGML_F32_EPR/2 ; ++qi) {
+ xf[qi*2] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf)];
+ xf[qi*2+1] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf0) >> 4];
+ }
+
+ ax[j] = GGML_F32_VEC_MUL(GGML_F32_VEC_LOAD(xf), scalev);
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+ }
+ }
+ GGML_F32_VEC_REDUCE(sumf, sum);
+
+ // leftovers
+ for (int i = np; i < n; i+=2) {
+ const int ib = i / MXFP4;
+ const block_mxfp4 * GGML_RESTRICT x = &xx[ib];
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } scale;
+ scale.as_bits = (((uint32_t)x->d) << 23);
+ sumf += y[i] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf)];
+ sumf += y[i+1] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf0) >> 4];
+ }
+
+
+#else // defined(GGML_SIMD)
+ const int nb = n / MXFP4;
+ assert(n % MXFP4 == 0);
+
+ int yi = 0;
+
+ const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx;
+
+ ggml_float sumf = 0.0;
+ for (int ib = 0; ib < nb; ++ib) {
+ const block_mxfp4 * GGML_RESTRICT x = &xx[ib + 0];
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } scale;
+ scale.as_bits = (((uint32_t)x->d) << 23);
+ for (int i = 0; i < MXFP4/2; ++i) {
+ sumf += mxfp4_table[(x->qs[i] & 0xf)] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2]);
+ sumf += mxfp4_table[(x->qs[i] & 0xf0) >> 4] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2+1]);
+ }
+ }
+#endif
+
+ *s = sumf;
+}
diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h
index 23cbb305..7480ca08 100644
--- a/ggml/src/ggml-cpu/vec.h
+++ b/ggml/src/ggml-cpu/vec.h
@@ -42,6 +42,8 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
+void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
+
void ggml_vec_silu_f32(const int n, float * y, const float * x);
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index c6dec427..0e016ccc 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -571,6 +571,82 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
+// MXFP4 dequantize derived from dequantize_block_q4_0
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+ const uint16_t dst_bias = 15;
+ const uint16_t dst_0p5 = 0x3800;
+ const uint16_t dst_m_bits = 10;
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ const uint64_t offset = 256*i + MXFP4*ir + 8*il;
+ dst_t * y = yy + offset;
+
+ const block_mxfp4 * x = (const block_mxfp4 *)vx + ib;
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } scale;
+ scale.as_bits = (((uint32_t)x->d) << 23);
+
+ // offset within the block 1/4 chunks (8 items)
+ const uint8_t * q = x->qs + 4*il;
+
+ for (int l = 0; l < 4; ++l) {
+ uint16_t em0 = q[l] & 0x07;
+ uint16_t em1 = q[l] & 0x70;
+ // float16 values
+ iq1m_scale_t x0;
+ iq1m_scale_t x1;
+
+ x0.u16 = (em0 << (dst_m_bits - 1)) | ((q[l] & 0x08) << 12);
+ x1.u16 = (em1 << (dst_m_bits - 5)) | ((q[l] & 0x80) << 8);
+
+ // Three cases:
+ // x is normal and non-zero: Correct bias
+ if ((em0 & 0x06) != 0) {
+ x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
+ }
+ if ((em1 & 0x60) != 0) {
+ x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
+ }
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+ if (em0 == 0x01) {
+ x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
+ }
+ if (em1 == 0x10) {
+ x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
+ }
+ // x is zero, do nothing
+
+ // XXX it looks correct here - but mulmat still gives bad results...
+ // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n",
+ // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 0, scale * float(x0.f16));
+ // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n",
+ // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 1, scale * float(x1.f16));
+
+ y[l*2] = scale.as_value * float(x0.f16);
+ y[l*2+1] = scale.as_value * float(x1.f16);
+ }
+}
+
+// derived from dequantize_row_q4_0_cuda
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -664,6 +740,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
default:
return nullptr;
}
@@ -713,6 +791,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
return convert_unary_cont_cuda<nv_bfloat16>;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
default:
return nullptr;
}
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 28ccf4be..bb19b06e 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -21,6 +21,7 @@
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvmxfp4.cuh"
#include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh"
@@ -1202,7 +1203,7 @@ static void ggml_cuda_op_mul_mat_cublas(
const int cc = ggml_cuda_info().devices[id].cc;
- const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
+ const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT && src0->type != GGML_TYPE_MXFP4;
if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
@@ -1924,7 +1925,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
+ && src0->type != GGML_TYPE_MXFP4;
+ bool use_mul_mat_vec_mxfp4 = src0->type == GGML_TYPE_MXFP4
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -1978,6 +1983,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
+ } else if (use_mul_mat_vec_mxfp4) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_mxfp4, nullptr);
} else {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
}
@@ -1997,6 +2004,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ if (ne2 == 1 && src0->type == GGML_TYPE_MXFP4) {
+ ggml_cuda_mul_mat_vec_mxfp4(ctx, src0, src1, ids, dst);
+ return;
+ }
if (ne2 == 1) {
if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
@@ -3056,6 +3067,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_BF16:
+ case GGML_TYPE_MXFP4:
#ifdef GGML_USE_MUSA
if (a->type == GGML_TYPE_Q3_K) {
return false;
diff --git a/ggml/src/ggml-cuda/mmvmxfp4.cu b/ggml/src/ggml-cuda/mmvmxfp4.cu
new file mode 100644
index 00000000..da62062b
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmvmxfp4.cu
@@ -0,0 +1,307 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "mmvmxfp4.cuh"
+
+// MXFP4 implementation derived from mmv.cu float32 code paths
+typedef union {
+ half f16;
+ uint16_t u16;
+} f16_t;
+
+template <typename type_acc, int block_size> // TODO type_acc unused - consider bf16 support
+static __global__ void mul_mat_vec_mxfp4(
+ const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+ const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
+ const int64_t row = blockIdx.x;
+ const int64_t channel_dst = blockIdx.y;
+ const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
+ const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
+ const int64_t sample_dst = blockIdx.z;
+ const int64_t sample_x = sample_dst / sample_ratio;
+ const int64_t sample_y = sample_dst;
+ const int tid = threadIdx.x;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ const uint16_t dst_bias = 15;
+ const uint16_t dst_0p5 = 0x3800;
+ const uint16_t dst_m_bits = 10;
+
+ x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+ y += sample_y *stride_sample_y + channel_y *stride_channel_y;
+ dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
+
+ const float2 * y2 = (const float2 *) y;
+
+ extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float)
+ float * buf_iw = (float *) data_mmv;
+
+ if (block_size > warp_size) {
+ if (tid < warp_size) {
+ buf_iw[tid] = 0.0f;
+ }
+ __syncthreads();
+ }
+
+ float sumf = 0.0f;
+
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ int offset0 = col2 / (MXFP4/2);
+ int i = col2 % (MXFP4/2);
+ const block_mxfp4 *x2 = x+offset0;
+
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } scale;
+ scale.as_bits = (((uint32_t)x2->d) << 23);
+ uint16_t em0 = x2->qs[i] & 0x07;
+ uint16_t em1 = x2->qs[i] & 0x70;
+ // float16 values
+ f16_t x0;
+ f16_t x1;
+ x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12);
+ x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8);
+
+ // Three cases:
+ // x is normal and non-zero: Correct bias
+ if ((em0 & 0x06) != 0) {
+ x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
+ }
+ if ((em1 & 0x60) != 0) {
+ x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
+ }
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+ if (em0 == 0x01) {
+ x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
+ }
+ if (em1 == 0x10) {
+ x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
+ }
+ // x is zero, do nothing
+
+ if (isnan(scale.as_value)) {
+ sumf = scale.as_value;
+ break;
+ }
+
+ const float2 tmpx = {x0.f16, x1.f16};
+ const float2 tmpy = y2[col2];
+ sumf += tmpx.x*tmpy.x*scale.as_value;
+ sumf += tmpx.y*tmpy.y*scale.as_value;
+ }
+
+ sumf = warp_reduce_sum<warp_size>(sumf);
+
+ if (block_size > warp_size) {
+ buf_iw[tid/warp_size] = sumf;
+ __syncthreads();
+ if (tid >= warp_size) {
+ return;
+ }
+ sumf = buf_iw[tid];
+ sumf = warp_reduce_sum<warp_size>(sumf);
+ }
+
+ if (tid != 0) {
+ return;
+ }
+
+ dst[row] = sumf;
+}
+
+template <typename type_acc>
+static void launch_mul_mat_vec_cuda_mxfp4(
+ const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream) {
+ GGML_ASSERT(ncols % 2 == 0);
+ // GGML_ASSERT(stride_row % 2 == 0); // TODO
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
+ const int64_t channel_ratio = nchannels_dst / nchannels_x;
+ const int64_t sample_ratio = nsamples_dst / nsamples_x;
+ int device;
+ int warp_size;
+
+ CUDA_CHECK(cudaGetDevice(&device));
+ warp_size = ggml_cuda_info().devices[device].warp_size;
+
+ int64_t block_size_best = warp_size;
+ int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
+ int64_t max_block_size = 256;
+ if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+ max_block_size = 128;
+ }
+ for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
+ const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+ if (niter < niter_best) {
+ niter_best = niter;
+ block_size_best = block_size;
+ }
+ }
+
+ const int smem = warp_size*sizeof(float);
+ const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
+ const dim3 block_dims(block_size_best, 1, 1);
+
+ switch (block_size_best) {
+ case 32: {
+ mul_mat_vec_mxfp4<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } break;
+ case 64: {
+ mul_mat_vec_mxfp4<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } break;
+ case 96: {
+ mul_mat_vec_mxfp4<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } break;
+ case 128: {
+ mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } break;
+ case 160: {
+ mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } break;
+ case 192: {
+ mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } break;
+ case 224: {
+ mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } break;
+ case 256: {
+ mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+}
+
+static void mul_mat_vec_cuda_mxfp4(
+ const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ enum ggml_prec prec, cudaStream_t stream) {
+ launch_mul_mat_vec_cuda_mxfp4<float>
+ (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+}
+
+void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const size_t ts_src0 = ggml_type_size(src0->type);
+ const size_t ts_src1 = ggml_type_size(src1->type);
+ const size_t ts_dst = ggml_type_size(dst->type);
+
+ GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
+ GGML_ASSERT(ne13 == ne3);
+
+ // GGML_ASSERT( nb00 == ts_src0); // TODO adjust for block sizing logic
+ GGML_ASSERT( nb10 == ts_src1);
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+ GGML_ASSERT( nb0 == ts_dst);
+
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+ const float * src1_d = (const float *) src1->data;
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
+ float * dst_d = (float *) dst->data;
+
+ const int64_t stride_row = src0->nb[1] / ts_src0;
+ const int64_t s11 = src1->nb[1] / ts_src1;
+ const int64_t s1 = dst->nb[1] / ts_dst;
+ const int64_t stride_channel_x = src0->nb[2] / ts_src0;
+ const int64_t s12 = src1->nb[2] / ts_src1;
+ const int64_t s2 = dst->nb[2] / ts_dst;
+ const int64_t stride_sample_x = src0->nb[3] / ts_src0;
+ const int64_t stride_sample_y = src1->nb[3] / ts_src1;
+ const int64_t stride_sample_dst = dst->nb[3] / ts_dst;
+ const int64_t nsamples_dst = ne3;
+ const int64_t nsamples_x = ne03;
+ const int64_t nchannels_x = ne02;
+ const int64_t nrows = ne01;
+ const int64_t ncols = ne00;
+
+ // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+ const int64_t ncols_dst = ids ? ne2 : ne1;
+ const int64_t nchannels_y = ids ? ne11 : ne12;
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
+ const int64_t stride_channel_dst = ids ? s1 : s2;
+ const int64_t stride_channel_y = ids ? s11 : s12;
+
+ GGML_ASSERT(ncols_dst == 1);
+
+ const block_mxfp4 * src0_d = (const block_mxfp4 *) src0->data;
+ mul_mat_vec_cuda_mxfp4(src0_d, src1_d, ids_d, dst_d, ncols, nrows, stride_row,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, ctx.stream());
+}
+
+void ggml_cuda_op_mul_mat_vec_mxfp4(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ GGML_ASSERT(src1_ncols == 1);
+
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+ // ggml_cuda_op provides single, contiguous matrices
+ const int64_t stride_row = ne00 / MXFP4;
+ const int64_t nchannels_x = 1;
+ const int64_t nchannels_y = 1;
+ const int64_t nchannels_dst = 1;
+ const int64_t stride_channel_x = 0;
+ const int64_t stride_channel_y = 0;
+ const int64_t stride_channel_dst = 0;
+ const int64_t nsamples_x = 1;
+ const int64_t nsamples_dst = 1;
+ const int64_t stride_sample_x = 0;
+ const int64_t stride_sample_y = 0;
+ const int64_t stride_sample_dst = 0;
+
+ const block_mxfp4 * src0_d = (const block_mxfp4 *) src0_dd_i;
+ mul_mat_vec_cuda_mxfp4(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+
+ GGML_UNUSED(ctx);
+ GGML_UNUSED(src1);
+ GGML_UNUSED(dst);
+ GGML_UNUSED(src1_ddq_i);
+ GGML_UNUSED(src1_ncols);
+ GGML_UNUSED(src1_padded_row_size);
+}
diff --git a/ggml/src/ggml-cuda/mmvmxfp4.cuh b/ggml/src/ggml-cuda/mmvmxfp4.cuh
new file mode 100644
index 00000000..a08fc780
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmvmxfp4.cuh
@@ -0,0 +1,9 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec_mxfp4(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 17eab976..938386ba 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -65,6 +65,9 @@
#define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2
+#define N_R0_MXFP4 4
+#define N_SG_MXFP4 2
+
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index ab46f6e3..d8e05a21 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -40,6 +40,7 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
static struct ggml_backend_reg g_ggml_backend_metal_reg;
static struct ggml_backend_device g_ggml_backend_metal_device;
+
// information about a Metal device
// note: assumes single GPU device - the default one
// TODO: support multiple GPU devices
@@ -209,6 +210,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -288,6 +290,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
@@ -310,6 +313,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
@@ -334,6 +338,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
@@ -934,7 +939,7 @@ static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfl
MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep;
-
+
//[options setFastMathEnabled:false];
metal_library = [device newLibraryWithSource:src options:options error:&error];
@@ -1157,6 +1162,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1236,6 +1242,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
@@ -1258,6 +1265,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
@@ -1282,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
@@ -3007,6 +3016,7 @@ static bool ggml_metal_encode_node(
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
default: GGML_ABORT("MUL MAT-MAT not implemented");
}
@@ -3212,6 +3222,12 @@ static bool ggml_metal_encode_node(
smem = 32*sizeof(float);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ nsg = N_SG_MXFP4;
+ nr0 = N_R0_MXFP4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
+ } break;
default:
{
GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -3396,6 +3412,7 @@ static bool ggml_metal_encode_node(
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
default: GGML_ABORT("MUL_MAT_ID not implemented");
}
@@ -3607,6 +3624,12 @@ static bool ggml_metal_encode_node(
smem = 32*sizeof(float);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ nsg = N_SG_MXFP4;
+ nr0 = N_R0_MXFP4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
+ } break;
default:
{
GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 08e8d807..69fa17de 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -1902,16 +1902,16 @@ void mul_vec_q_n_f32_impl(
device const char * src1,
device char * dst,
threadgroup char * shmem,
- uint3 tgpig,
- ushort tiisg,
- ushort sgitg) {
- const int nb = args.ne00/QK4_0;
+ uint3 tgpig, // Threadgroup Position in Grid
+ ushort tiisg, // Thread Index in SIMD Group
+ ushort sgitg) { // SIMD Group Index in ThreadGroup
+ const int nb = args.ne00/QK4_0; // src0->ne[0] / 32
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -6744,6 +6744,49 @@ kernel void kernel_mul_mm_id(
}
}
+template <typename type4x4>
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
+ float4x4 reg_f;
+ const ushort dst_bias = 15;
+ const ushort dst_0p5 = 0x3800;
+ const ushort dst_m_bits = 10;
+ const half scale = (half)(as_type<float>(((uint32_t)xb->d) << 23));
+ // il:0 first 16, il:1 last 16
+ for (int i = 0; i < 8; i++) {
+ ushort em0 = xb->qs[il*8 + i] & 0x07;
+ ushort em1 = xb->qs[il*8 + i] & 0x70;
+ // float16 values
+ ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12);
+ ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8);
+
+ // Three cases:
+ // x is normal and non-zero: Correct bias
+ if ((em0 & 0x06) != 0) {
+ x0 = x0 + ((dst_bias - 1) << dst_m_bits);
+ }
+ if ((em1 & 0x60) != 0) {
+ x1 = x1 + ((dst_bias - 1) << dst_m_bits);
+ }
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+ if (em0 == 0x01) {
+ x0 = dst_0p5 | (x0 & 0x8000);
+ }
+ if (em1 == 0x10) {
+ x1 = dst_0p5 | (x1 & 0x8000);
+ }
+ // x is zero, do nothing
+
+ if (isnan(scale)) {
+ reg_f[i/2][2*(i%2) + 0] = scale;
+ reg_f[i/2][2*(i%2) + 1] = scale;
+ } else {
+ reg_f[i/2][2*(i%2) + 0] = scale * as_type<half>(x0);
+ reg_f[i/2][2*(i%2) + 1] = scale * as_type<half>(x1);
+ }
+ }
+ reg = (type4x4) reg_f;
+}
+
#define QK_NL 16
//
@@ -6811,6 +6854,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
+
//
// indirect matrix-matrix multiplication
//
@@ -6842,6 +6887,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
+
//
// matrix-vector multiplication
@@ -6958,6 +7005,120 @@ kernel void kernel_mul_mv_id(
sgitg);
}
+// MXFP4 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y
+void mul_mv_mxfp4_f32_impl(
+ ggml_metal_kargs_mul_mv args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+ const ushort dst_bias = 15;
+ const ushort dst_0p5 = 0x3800;
+ const ushort dst_m_bits = 10;
+ const int nr0 = N_R0_MXFP4;
+ const int nsg = N_SG_MXFP4;
+ const int nw = N_SIMDWIDTH;
+ const int nb = args.ne00/MXFP4;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * nsg + sgitg) * nr0;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const float * y = (device const float *) (src1 + offset1);
+
+ // pointers to src0 rows
+ device const block_mxfp4 * ax[nr0];
+ for (int row = 0; row < nr0; ++row) {
+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+ ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0);
+ }
+
+ float yl[16]; // src1 vector cache
+ float sumf[nr0] = {0.f};
+
+ const short ix = (tiisg/2);
+ const short il = (tiisg%2)*16;
+
+ device const float * yb = y + ix*MXFP4 + il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+
+#pragma unroll
+ for (short row = 0; row < nr0; row++) {
+ // Processes 16 items
+ device const block_mxfp4 * qb_curr = ax[row] + ib;
+ float d = as_type<float>(((uint32_t)(ax[row] + ib)->d) << 23);
+ // il = 0 or 16
+ device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2);
+ for (int i = 0; i < 8; ++i) {
+ ushort em0 = qs[i] & 0x07;
+ ushort em1 = qs[i] & 0x70;
+ ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12);
+ ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8);
+ // Three cases:
+ // x is normal and non-zero: Correct bias
+ if ((em0 & 0x06) != 0) {
+ x0 = x0 + ((dst_bias - 1) << dst_m_bits);
+ }
+ if ((em1 & 0x60) != 0) {
+ x1 = x1 + ((dst_bias - 1) << dst_m_bits);
+ }
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+ if (em0 == 0x01) {
+ x0 = dst_0p5 | (x0 & 0x8000);
+ }
+ if (em1 == 0x10) {
+ x1 = dst_0p5 | (x1 & 0x8000);
+ }
+ // x is zero, do nothing
+ if (!isnan(d)) {
+ sumf[row] += yb[i*2] * as_type<half>(x0) * d
+ + yb[i*2+1] * as_type<half>(x1) * d;
+ } else {
+ sumf[row] = d;
+ }
+ }
+ }
+
+ yb += MXFP4 * 16;
+ }
+
+ device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
+
+ for (int row = 0; row < nr0; ++row) {
+ const float tot = simd_sum(sumf[row]);
+
+ if (tiisg == 0 && first_row + row < args.ne01) {
+ dst_f32[first_row + row] = tot;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_mxfp4_f32")]]
+kernel void kernel_mul_mv_mxfp4_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
@@ -6987,6 +7148,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_mv_mxfp4_f32_impl>>;
+
kernel void kernel_pool_2d_max_f32(
device const float * src0,
device float * dst,
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 84ec6dfe..17c308aa 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -4925,6 +4925,144 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE
quantize_iq2_s(x, y, 1, k, NULL);
}
+// =============================== mxfp4 (de)-quantization
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+ static const int qk = MXFP4;
+ static const uint32_t E8_BIAS = 127;
+ static const uint32_t E2_BIAS = 1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ }
+ }
+
+ const float dequant_scale = amax / 6.0f;
+ uint32_t dequant_scale_exponent = 0;
+ memcpy(&dequant_scale_exponent, &dequant_scale, sizeof(dequant_scale_exponent));
+
+ // Rounding up
+ dequant_scale_exponent = (dequant_scale_exponent + 0x007FFFFF) & 0x7F800000;
+ // Rounding down
+ // dequant_scale_exponent = dequant_scale_exponent & 0x7F800000;
+
+ float dequant_scale_rounded = 0.0f;
+ memcpy(&dequant_scale_rounded, &dequant_scale_exponent, sizeof(dequant_scale_rounded));
+ float quant_scale = 0.0f;
+ if (dequant_scale_rounded != 0.0f) {
+ quant_scale = 1.0f / dequant_scale_rounded;
+ }
+
+ y[i].d = (uint8_t)(dequant_scale_exponent >> 23);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + j*2]*quant_scale;
+ const float x1 = x[i*qk + j*2+1]*quant_scale;
+
+ uint32_t xi0 = 0;
+ uint32_t xi1 = 0;
+ memcpy(&xi0, &x0, sizeof(xi0));
+ memcpy(&xi1, &x1, sizeof(xi1));
+
+ uint32_t s0 = xi0 & 0x80000000;
+ uint32_t s1 = xi1 & 0x80000000;
+ uint32_t e0 = (xi0 >> 23) & 0xFF;
+ uint32_t e1 = (xi1 >> 23) & 0xFF;
+ uint32_t m0 = (xi0 & 0x7FFFFF);
+ uint32_t m1 = (xi1 & 0x7FFFFF);
+
+ // 0.25 <= x < 0.75 maps to 0.5, a denormal number
+ // Move implicit bit 1 at the beginning to mantissa for denormals
+ // adjusted_exponents
+ uint32_t ae0 = E8_BIAS - (e0 + 1);
+ uint32_t ae1 = E8_BIAS - (e1 + 1);
+ if (e0 < E8_BIAS) {
+ m0 = (0x400000 | (m0 >> 1)) >> ae0;
+ }
+ if (e1 < E8_BIAS) {
+ m1 = (0x400000 | (m1 >> 1)) >> ae1;
+ }
+
+ // For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
+ e0 = MAX(e0, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS);
+ e1 = MAX(e1, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS);
+
+ // Combine sign, exponent, and mantissa, while saturating
+ // rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
+ uint32_t tmp0 = MIN((((e0 << 2) | (m0 >> 21)) + 1) >> 1, 0x7);
+ uint32_t tmp1 = MIN((((e1 << 2) | (m1 >> 21)) + 1) >> 1, 0x7);
+ uint8_t v0 = (uint8_t)((s0 >> 28) | tmp0);
+ uint8_t v1 = (uint8_t)((s1 >> 28) | tmp1);
+ y[i].qs[j] = v0;
+ y[i].qs[j] |= v1 << 4;
+ }
+ }
+}
+
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+ assert(k % MXFP4 == 0);
+
+ const int nb = k / MXFP4;
+ const uint16_t dst_bias = 15;
+ const uint16_t dst_0p5 = 0x3800;
+ const uint16_t dst_m_bits = 10;
+
+ for (int i = 0; i < nb; i++) {
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } scale;
+ scale.as_bits = (((uint32_t)x[i].d) << 23);
+ for (int j = 0; j < MXFP4/2; ++j) {
+ uint16_t em0 = x[i].qs[j] & 0x07;
+ uint16_t em1 = x[i].qs[j] & 0x70;
+ // float16 values
+ uint16_t x0 = (em0 << (dst_m_bits - 1)) | ((x[i].qs[j] & 0x08) << 12);
+ uint16_t x1 = (em1 << (dst_m_bits - 5)) | ((x[i].qs[j] & 0x80) << 8);
+
+ // Three cases:
+ // x is normal and non-zero: Correct bias
+ if ((em0 & 0x06) != 0) {
+ x0 = x0 + ((dst_bias - 1) << dst_m_bits);
+ }
+ if ((em1 & 0x60) != 0) {
+ x1 = x1 + ((dst_bias - 1) << dst_m_bits);
+ }
+ // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
+ if (em0 == 0x01) {
+ x0 = dst_0p5 | (x0 & 0x8000);
+ }
+ if (em1 == 0x10) {
+ x1 = dst_0p5 | (x1 & 0x8000);
+ }
+ // x is zero, do nothing
+
+ if (isnan(scale.as_value)) {
+ y[i*MXFP4 + j*2] = scale.as_value;
+ y[i*MXFP4 + j*2+1] = scale.as_value;
+ } else {
+ y[i*MXFP4 + j*2] = GGML_FP16_TO_FP32(x0)*scale.as_value;
+ y[i*MXFP4 + j*2+1] = GGML_FP16_TO_FP32(x1)*scale.as_value;
+ }
+ }
+ }
+}
+
+
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+}
+
// =============================== data validation
static bool validate_float(float f, size_t i) {
@@ -5214,7 +5352,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
-
+ case GGML_TYPE_MXFP4:
+ // TODO - anything to validate?
+ break;
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
index d09173e1..2fc40f75 100644
--- a/ggml/src/ggml-quants.h
+++ b/ggml/src/ggml-quants.h
@@ -37,6 +37,8 @@ GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_
GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+
// Dequantization
GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -65,6 +67,8 @@ GGML_API void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, floa
GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);
GGML_API void iq3xs_init_impl(int grid_size);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 8a654624..0f3c9834 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -589,11 +589,13 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) dequantize_row_q4_1,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
},
- [4] = { // GGML_TYPE_Q4_2
- .type_name = "DEPRECATED",
- .blck_size = 0,
- .type_size = 0,
- .is_quantized = false,
+ [GGML_TYPE_MXFP4] = { // formerly deprecated GGML_TYPE_Q4_2
+ .type_name = "mxfp4",
+ .blck_size = MXFP4,
+ .type_size = sizeof(block_mxfp4),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_mxfp4,
+ .from_float_ref = (ggml_from_float_t) quantize_row_mxfp4_ref,
},
[5] = { // GGML_TYPE_Q4_3
.type_name = "DEPRECATED",
@@ -6446,6 +6448,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_F16:
{
size_t elemsize = sizeof(ggml_fp16_t);