mirror of https://github.com/ollama/ollama.git
Vulkan Fix FA coopmat1 invalid array indexing
This commit is contained in:
parent
632948554c
commit
c14680095c
|
@ -0,0 +1,30 @@
|
||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Jeff Bolz <jbolz@nvidia.com>
|
||||||
|
Date: Fri, 3 Oct 2025 04:52:46 -0500
|
||||||
|
Subject: [PATCH] vulkan: Fix FA coopmat1 invalid array indexing (#16365)
|
||||||
|
|
||||||
|
When computing sinks, the cm1 shader was looping r from 0 to Br rather than
|
||||||
|
to rows_per_thread. I must have copied this from the scalar path (where it is
|
||||||
|
correct), and somehow it wasn't causing failures on current drivers.
|
||||||
|
---
|
||||||
|
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 4 ++--
|
||||||
|
1 file changed, 2 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
|
||||||
|
index e76dbb4de..0507df2d8 100644
|
||||||
|
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
|
||||||
|
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
|
||||||
|
@@ -358,8 +358,8 @@ void main() {
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||||
|
- [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||||
|
- float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
|
||||||
|
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
+ float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
|
||||||
|
|
||||||
|
float ms = 1.0f;
|
||||||
|
float vs = 1.0f;
|
||||||
|
--
|
||||||
|
2.51.0.windows.1
|
||||||
|
|
|
@ -356,8 +356,8 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
|
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
|
||||||
|
|
||||||
float ms = 1.0f;
|
float ms = 1.0f;
|
||||||
float vs = 1.0f;
|
float vs = 1.0f;
|
||||||
|
|
Loading…
Reference in New Issue