diff --git a/llama/patches/0029-vulkan-Fix-FA-coopmat1-invalid-array-indexing.patch b/llama/patches/0029-vulkan-Fix-FA-coopmat1-invalid-array-indexing.patch new file mode 100644 index 000000000..af4085440 --- /dev/null +++ b/llama/patches/0029-vulkan-Fix-FA-coopmat1-invalid-array-indexing.patch @@ -0,0 +1,30 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Jeff Bolz +Date: Fri, 3 Oct 2025 04:52:46 -0500 +Subject: [PATCH] vulkan: Fix FA coopmat1 invalid array indexing (#16365) + +When computing sinks, the cm1 shader was looping r from 0 to Br rather than +to rows_per_thread. I must have copied this from the scalar path (where it is +correct), and somehow it wasn't causing failures on current drivers. +--- + ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +index e76dbb4de..0507df2d8 100644 +--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp ++++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +@@ -358,8 +358,8 @@ void main() { + } + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { +- [[unroll]] for (uint32_t r = 0; r < Br; ++r) { +- float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); ++ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { ++ float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; +-- +2.51.0.windows.1 + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index ddb1246e0..ef1ce0503 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -356,8 +356,8 @@ void main() { } if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); float ms = 1.0f; float vs = 1.0f;