This commit is contained in:
Thomas Stocker 2025-10-07 21:30:15 +00:00 committed by GitHub
commit 7abbd6c876
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
157 changed files with 29376 additions and 15 deletions

View File

@ -52,6 +52,12 @@ jobs:
container: rocm/dev-ubuntu-22.04:6.1.2
extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
- preset: Vulkan
container: ubuntu:22.04
extra-packages: >
mesa-vulkan-drivers vulkan-tools
libvulkan1 libvulkan-dev
vulkan-sdk cmake ccache g++ make
runs-on: linux
container: ${{ matrix.container }}
steps:
@ -59,7 +65,19 @@ jobs:
- run: |
[ -n "${{ matrix.container }}" ] || sudo=sudo
$sudo apt-get update
# Add LunarG Vulkan SDK apt repo for Ubuntu 22.04
if [ "${{ matrix.preset }}" = "Vulkan" ]; then
$sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg
# Use signed-by to bind the repo to the installed keyring to avoid NO_PUBKEY
echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313 jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list > /dev/null
$sudo apt-get update
fi
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
# Export VULKAN_SDK if provided by LunarG package (defensive)
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
fi
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/cache@v4
@ -92,18 +110,21 @@ jobs:
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
- preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.313.2/windows/vulkansdk-windows-X64-1.4.313.2.exe
runs-on: windows
steps:
- run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm'
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: matrix.preset == 'CUDA'
name: Install CUDA ${{ matrix.cuda-version }}
@ -133,6 +154,18 @@ jobs:
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'Vulkan'
name: Install Vulkan ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:

View File

@ -139,3 +139,15 @@ if(CMAKE_HIP_COMPILER)
endforeach()
endif()
endif()
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()

View File

@ -70,6 +70,10 @@
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
},
{
"name": "Vulkan",
"inherits": [ "Default" ]
}
],
"buildPresets": [
@ -122,6 +126,11 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 6"
},
{
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
}
]
}

View File

@ -7,6 +7,7 @@ ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
ARG VULKANVERSION=1.4.321.1
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
@ -17,6 +18,16 @@ RUN yum install -y yum-utils \
&& dnf install -y ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
ARG VULKANVERSION
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& dnf -y install ninja-build \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
@ -106,6 +117,13 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' -DOLLAMA_RUNNER_DIR="vulkan" \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
COPY go.mod go.sum .
@ -123,7 +141,8 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@ -136,12 +155,13 @@ FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
FROM ${FLAVOR} AS archive
ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin

View File

@ -72,6 +72,7 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList {
} else {
info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor)
}
// TODO any special processing of Vulkan devices?
resp = append(resp, info)
}
if len(resp) == 0 {
@ -99,7 +100,16 @@ func (l GpuInfoList) GetVisibleDevicesEnv() []string {
if len(l) == 0 {
return nil
}
return []string{rocmGetVisibleDevicesEnv(l)}
res := []string{}
envVar := rocmGetVisibleDevicesEnv(l)
if envVar != "" {
res = append(res, envVar)
}
envVar = vkGetVisibleDevicesEnv(l)
if envVar != "" {
res = append(res, envVar)
}
return res
}
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
@ -129,6 +139,25 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
return envVar + strings.Join(ids, ",")
}
func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "Vulkan" {
continue
}
if info.filterID != "" {
ids = append(ids, info.filterID)
} else {
ids = append(ids, info.ID)
}
}
if len(ids) == 0 {
return ""
}
envVar := "GGML_VK_VISIBLE_DEVICES="
return envVar + strings.Join(ids, ",")
}
// GetSystemInfo returns the last cached state of the GPUs on the system
func GetSystemInfo() SystemInfo {
deviceMu.Lock()

View File

@ -84,6 +84,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
// are enumerated, but not actually supported.
// We run this in serial to avoid potentially initializing a GPU multiple
// times concurrently leading to memory contention
// TODO refactor so we group the lib dirs and do serial per version, but parallel for different libs
for dir := range libDirs {
var dirs []string
if dir == "" {
@ -127,13 +128,17 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
} else {
envVar = "ROCR_VISIBLE_DEVICES"
}
} else {
} else if devices[i].Library == "CUDA" {
envVar = "CUDA_VISIBLE_DEVICES"
} else if devices[i].Library == "Vulkan" {
envVar = "GGML_VK_VISIBLE_DEVICES"
} else {
slog.Error("Unknown Library:" + devices[i].Library)
}
extraEnvs := []string{
"GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs
envVar + "=" + devices[i].ID, // Filter to just this one GPU
"GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs
envVar + "=" + devices[i].FilteredID, // Filter to just this one GPU
}
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
needsDelete[i] = true
@ -153,6 +158,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
wg.Wait()
logutil.Trace("supported GPU library combinations", "supported", supported)
filterOutVulkanThatAreSupportedByOtherGPU(needsDelete)
// Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible
filterOverlapByLibrary(supported, needsDelete)
@ -174,7 +181,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
}
// Now filter out any overlap with different libraries (favor CUDA/ROCm over others)
// Now filter out any overlap with different libraries (favor CUDA/HIP over others)
for i := 0; i < len(devices); i++ {
for j := i + 1; j < len(devices); j++ {
// For this pass, we only drop exact duplicates
@ -336,6 +343,37 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
return devices
}
func filterOutVulkanThatAreSupportedByOtherGPU(needsDelete []bool) {
// Filter out Vulkan devices that share a PCI ID with a non-Vulkan device that is not marked for deletion
for i := range devices {
if devices[i].Library != "Vulkan" || needsDelete[i] {
continue
}
if devices[i].PCIID == "" {
continue
}
for j := range devices {
if i == j {
continue
}
if devices[j].PCIID == "" {
continue
}
if devices[j].PCIID == devices[i].PCIID && devices[j].Library != "Vulkan" && !needsDelete[j] {
needsDelete[i] = true
slog.Debug("dropping Vulkan duplicate by PCI ID",
"vulkan_id", devices[i].ID,
"vulkan_libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1],
"pci_id", devices[i].PCIID,
"kept_library", devices[j].Library,
"kept_id", devices[j].ID,
)
break
}
}
}
}
func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) {
// For multi-GPU systems, use the newest version that supports all the GPUs
for _, byLibDirs := range supported {
@ -441,6 +479,7 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
}
// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
pathNeeded := true
@ -498,6 +537,14 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
}
}
logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices)
// Enumerate returned devices starting at 0 per library and assign the per-library index as FilteredID
libCounts := make(map[string]int)
for i := range devices {
lib := devices[i].Library
devices[i].FilteredID = strconv.Itoa(libCounts[lib])
libCounts[lib]++
}
return devices
}

View File

@ -37,7 +37,7 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
UnreliableFreeMemory bool
// GPU information
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
filterID string // AMD/Vulkan Workaround: The numeric ID of the device used to filter out other devices
Name string `json:"name"` // user friendly name if available
Compute string `json:"compute"` // Compute Capability or gfx
@ -174,7 +174,8 @@ func (l GpuInfoList) FlashAttentionSupported() bool {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7) ||
gpu.Library == "ROCm"
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
if !supportsFA {
return false

View File

@ -217,6 +217,7 @@ var (
CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES")
HipVisibleDevices = String("HIP_VISIBLE_DEVICES")
RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES")
VkVisibleDevices = String("GGML_VK_VISIBLE_DEVICES")
GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL")
HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION")
)
@ -307,6 +308,7 @@ func AsMap() map[string]EnvVar {
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}
ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"}
ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"}
ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}

View File

@ -69,7 +69,9 @@ func EnumerateGPUs() []ml.DeviceID {
for i := range C.ggml_backend_dev_count() {
device := C.ggml_backend_dev_get(i)
if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
switch C.ggml_backend_dev_type(device) {
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(device, &props)
ids = append(ids, ml.DeviceID{

View File

@ -0,0 +1,95 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Xiaodong Ye <xiaodong.ye@mthreads.com>
Date: Mon, 18 Aug 2025 12:48:07 +0800
Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5)
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++
1 file changed, 37 insertions(+)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 061cd078..adea7783 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -11588,6 +11588,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
snprintf(description, description_size, "%s", props.deviceName.data());
}
+static std::string ggml_vk_get_device_id(int device) {
+ ggml_vk_instance_init();
+
+ std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
+
+ vk::PhysicalDeviceProperties2 props;
+ vk::PhysicalDeviceIDProperties deviceIDProps;
+ props.pNext = &deviceIDProps;
+ devices[device].getProperties2(&props);
+
+ const auto& uuid = deviceIDProps.deviceUUID;
+ char id[64];
+ snprintf(id, sizeof(id),
+ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+ uuid[0], uuid[1], uuid[2], uuid[3],
+ uuid[4], uuid[5],
+ uuid[6], uuid[7],
+ uuid[8], uuid[9],
+ uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15]
+ );
+ return std::string(id);
+}
+
// backend interface
#define UNUSED GGML_UNUSED
@@ -12394,6 +12417,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
ggml_vk_get_device_description(dev_idx, description, description_size);
}
+std::string ggml_backend_vk_get_device_id(int device) {
+ GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+ int dev_idx = vk_instance.device_indices[device];
+ return ggml_vk_get_device_id(dev_idx);
+}
+
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
@@ -12481,6 +12510,7 @@ struct ggml_backend_vk_device_context {
std::string description;
bool is_integrated_gpu;
std::string pci_bus_id;
+ std::string id;
};
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
@@ -12493,6 +12523,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
return ctx->description.c_str();
}
+static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+ return ctx->id.c_str();
+}
+
static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
ggml_backend_vk_get_device_memory(ctx->device, free, total);
@@ -12519,6 +12554,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->name = ggml_backend_vk_device_get_name(dev);
props->description = ggml_backend_vk_device_get_description(dev);
+ props->id = ggml_backend_vk_device_get_id(dev);
props->type = ggml_backend_vk_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -12965,6 +13001,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
+ ctx->id = ggml_backend_vk_get_device_id(i);
devices.push_back(new ggml_backend_device {
/* .iface = */ ggml_backend_vk_device_i,
/* .reg = */ reg,
--
2.51.0

View File

@ -0,0 +1,253 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Fri Sep 5 08:25:03 2025 -0700
Subject: [PATCH] Vulkan PCI and Memory
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 176 ++++++++++++++++++++++-----
1 file changed, 145 insertions(+), 31 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index adea7783..fb7204ce 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -12423,31 +12423,99 @@ std::string ggml_backend_vk_get_device_id(int device) {
return ggml_vk_get_device_id(dev_idx);
}
-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
- GGML_ASSERT(device < (int) vk_instance.device_indices.size());
- GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
+//////////////////////////
+
+struct ggml_backend_vk_device_context {
+ size_t device;
+ std::string name;
+ std::string description;
+ bool is_integrated_gpu;
+ // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function)
+ std::string pci_id;
+ std::string id;
+ std::string uuid;
+ int major;
+ int minor;
+ int driver_major;
+ int driver_minor;
+ int pci_bus_id;
+ int pci_device_id;
+ int pci_domain_id;
+};
+
+void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) {
+ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size());
+ GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size());
+
+ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]];
- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
- vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
- vk::PhysicalDeviceMemoryProperties2 memprops = {};
- bool membudget_supported = vk_instance.device_supports_membudget[device];
+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+ vk::PhysicalDeviceProperties2 props2;
+ vkdev.getProperties2(&props2);
- if (membudget_supported) {
- memprops.pNext = &budgetprops;
+ if (!ctx->is_integrated_gpu)
+ {
+ // Use vendor specific management libraries for best VRAM reporting if available
+ switch (props2.properties.vendorID) {
+ case VK_VENDOR_ID_AMD:
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
+ ggml_hip_mgmt_release();
+ }
+ break;
+ case VK_VENDOR_ID_NVIDIA:
+ if (ggml_nvml_init() == 0) {
+ int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total);
+ ggml_nvml_release();
+ return;
+ }
+ ggml_nvml_release();
+ }
+ break;
+ }
}
- vkdev.getMemoryProperties2(&memprops);
+ // else fallback to memory budget if supported
- for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
- const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
+ *total = 0;
+ *free = 0;
+ vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props;
+ vk::PhysicalDeviceMemoryProperties2 memprops2;
+ memprops2.pNext = &mem_budget_props;
+ vkdev.getMemoryProperties2(&memprops2);
+ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) {
+ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *total += memprops2.memoryProperties.memoryHeaps[i].size;
+ } else if (ctx->is_integrated_gpu) {
+ // Include shared memory on iGPUs
+ *total += memprops2.memoryProperties.memoryHeaps[i].size;
+ }
+ }
+ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) {
+ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *free += mem_budget_props.heapBudget[i];
+ } else if (ctx->is_integrated_gpu) {
+ *free += mem_budget_props.heapBudget[i];
+ }
+ }
+ if (*total > 0 && *free > 0) {
+ return;
+ } else if (*total > 0) {
+ *free = *total;
+ return;
+ }
+ // else just report the physical memory
+ for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) {
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
*total = heap.size;
-
- if (membudget_supported && i < budgetprops.heapUsage.size()) {
- *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
- } else {
- *free = heap.size;
- }
+ *free = heap.size;
break;
}
}
@@ -12502,16 +12570,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
return std::string(pci_bus_id);
}
-//////////////////////////
-
-struct ggml_backend_vk_device_context {
- size_t device;
- std::string name;
- std::string description;
- bool is_integrated_gpu;
- std::string pci_bus_id;
- std::string id;
-};
+static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) {
+ if (id.empty()) return false;
+ unsigned int d = 0, b = 0, dev = 0, func = 0;
+ // Expected format: dddd:bb:dd.f (all hex)
+ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func);
+ if (n < 4) return false;
+ if (domain) *domain = (int) d;
+ if (bus) *bus = (int) b;
+ if (device) *device = (int) dev;
+ return true;
+}
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
@@ -12530,7 +12599,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
- ggml_backend_vk_get_device_memory(ctx->device, free, total);
+ ggml_backend_vk_get_device_memory(ctx, free, total);
}
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -12556,7 +12625,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->description = ggml_backend_vk_device_get_description(dev);
props->id = ggml_backend_vk_device_get_id(dev);
props->type = ggml_backend_vk_device_get_type(dev);
- props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
+ props->device_id = ctx->pci_id.empty() ? nullptr : ctx->pci_id.c_str();
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* .async = */ false,
@@ -12564,6 +12633,16 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
/* .buffer_from_host_ptr = */ false,
/* .events = */ false,
};
+
+ props->compute_major = ctx->major;
+ props->compute_minor = ctx->minor;
+ props->driver_major = ctx->driver_major;
+ props->driver_minor = ctx->driver_minor;
+ props->integrated = ctx->is_integrated_gpu;
+ props->pci_bus_id = ctx->pci_bus_id;
+ props->pci_device_id = ctx->pci_device_id;
+ props->pci_domain_id = ctx->pci_domain_id;
+ props->library = GGML_VK_NAME;
}
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -12992,6 +13071,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
+ std::vector<vk::PhysicalDevice> vk_devices = vk_instance.instance.enumeratePhysicalDevices();
+
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
char desc[256];
@@ -13000,13 +13081,46 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->name = GGML_VK_NAME + std::to_string(i);
ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
- ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
+ ctx->pci_id = ggml_backend_vk_get_device_pci_id(i);
ctx->id = ggml_backend_vk_get_device_id(i);
devices.push_back(new ggml_backend_device {
/* .iface = */ ggml_backend_vk_device_i,
/* .reg = */ reg,
/* .context = */ ctx,
});
+
+ // Gather additional information about the device
+ int dev_idx = vk_instance.device_indices[i];
+ vk::PhysicalDeviceProperties props1;
+ vk_devices[dev_idx].getProperties(&props1);
+ vk::PhysicalDeviceProperties2 props2;
+ vk::PhysicalDeviceIDProperties device_id_props;
+ vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props;
+ vk::PhysicalDeviceDriverProperties driver_props;
+ props2.pNext = &device_id_props;
+ device_id_props.pNext = &pci_bus_props;
+ pci_bus_props.pNext = &driver_props;
+ vk_devices[dev_idx].getProperties2(&props2);
+ std::ostringstream oss;
+ oss << std::hex << std::setfill('0');
+ oss << "GPU-";
+ int byteIdx = 0;
+ for (int i = 0; i < 16; ++i, ++byteIdx) {
+ oss << std::setw(2) << static_cast<int>(device_id_props.deviceUUID[i]);
+ if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) {
+ oss << '-';
+ }
+ }
+ ctx->uuid = oss.str();
+ ctx->pci_bus_id = pci_bus_props.pciBus;
+ ctx->pci_device_id = pci_bus_props.pciDevice;
+ ctx->pci_domain_id = pci_bus_props.pciDomain;
+ ctx->id = std::to_string(i);
+ ctx->major = 0;
+ ctx->minor = 0;
+ // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string
+ ctx->driver_major = 0;
+ ctx->driver_minor = 0;
}
initialized = true;
}
--
2.51.0

View File

@ -0,0 +1,30 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jeff Bolz <jbolz@nvidia.com>
Date: Fri, 3 Oct 2025 04:52:46 -0500
Subject: [PATCH] vulkan: Fix FA coopmat1 invalid array indexing (#16365)
When computing sinks, the cm1 shader was looping r from 0 to Br rather than
to rows_per_thread. I must have copied this from the scalar path (where it is
correct), and somehow it wasn't causing failures on current drivers.
---
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index e76dbb4de..0507df2d8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -358,8 +358,8 @@ void main() {
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
- [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
- float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
--
2.51.0.windows.1

View File

@ -567,6 +567,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
(gpus[0].Library == "cpu" && s.options.UseMMap == nil) ||
(gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
s.loadRequest.UseMmap = false
}

View File

@ -57,7 +57,8 @@ var initDevices = sync.OnceFunc(func() {
}
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
accels = append(accels, d)
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
gpus = append(gpus, d)
}
@ -470,7 +471,9 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
// Mimic llama runner logs summarizing layers and memory
gpuLayers := 0
for layer := range maps.Values(b.layers) {
if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
switch C.ggml_backend_dev_type(layer.d) {
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
gpuLayers++
}
}
@ -479,7 +482,8 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
switch C.ggml_backend_dev_type(b.output) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
slog.Info("offloading output layer to CPU")
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
slog.Info("offloading output layer to GPU")
gpuLayers++
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:

View File

@ -20,10 +20,13 @@ include /src/ggml-cuda/vendors/
include /src/ggml-cuda/template-instances/
include /src/ggml-hip/
include /src/ggml-metal/
include src/ggml-vulkan/
include src/ggml-vulkan/vulkan-shaders
include CMakeLists.txt
include *.[chm]
include *.cpp
include *.cu
include *.cuh
include *.metal
include *.comp
hide *

View File

@ -0,0 +1,200 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
find_package(Vulkan COMPONENTS glslc REQUIRED)
function(detect_host_compiler)
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
else()
find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
endif()
set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE)
set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
endfunction()
# Function to test shader extension support
# Parameters:
# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product")
# TEST_SHADER_FILE - Path to the test shader file
# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result
function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE)
execute_process(
COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error
)
if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*")
message(STATUS "${EXTENSION_NAME} not supported by glslc")
set(${RESULT_VARIABLE} OFF PARENT_SCOPE)
else()
message(STATUS "${EXTENSION_NAME} supported by glslc")
set(${RESULT_VARIABLE} ON PARENT_SCOPE)
add_compile_definitions(${RESULT_VARIABLE})
# Ensure the extension support is forwarded to vulkan-shaders-gen
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON)
set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE)
endif()
endfunction()
if (Vulkan_FOUND)
message(STATUS "Vulkan found")
ggml_add_backend_library(ggml-vulkan
ggml-vulkan.cpp
../../include/ggml-vulkan.h
)
set(VULKAN_SHADER_GEN_CMAKE_ARGS "")
# Test all shader extensions
test_shader_extension_support(
"GL_KHR_cooperative_matrix"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_NV_cooperative_matrix2"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_integer_dot_product"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_bfloat16"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
)
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
# Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
# Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
endif()
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
endif()
if (GGML_VULKAN_DEBUG)
add_compile_definitions(GGML_VULKAN_DEBUG)
endif()
if (GGML_VULKAN_MEMORY_DEBUG)
add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
endif()
if (GGML_VULKAN_SHADER_DEBUG_INFO)
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON)
endif()
if (GGML_VULKAN_VALIDATE)
add_compile_definitions(GGML_VULKAN_VALIDATE)
endif()
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
if (CMAKE_CROSSCOMPILING)
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
else()
detect_host_compiler()
if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER)
message(FATAL_ERROR "Host compiler not found")
else()
message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}")
endif()
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
endif()
else()
# For non-cross-compiling, use empty toolchain (use host compiler)
set(HOST_CMAKE_TOOLCHAIN_FILE "")
endif()
include(ExternalProject)
if (CMAKE_CROSSCOMPILING)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
endif()
ExternalProject_Add(
vulkan-shaders-gen
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$<CONFIG>
-DCMAKE_INSTALL_BINDIR=.
-DCMAKE_BUILD_TYPE=$<CONFIG>
${VULKAN_SHADER_GEN_CMAKE_ARGS}
BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $<CONFIG>
BUILD_ALWAYS TRUE
# NOTE: When DESTDIR is set using Makefile generators and
# "make install" triggers the build step, vulkan-shaders-gen
# would be installed into the DESTDIR prefix, so it is unset
# to ensure that does not happen.
INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR
${CMAKE_COMMAND} --install . --config $<CONFIG>
)
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp")
set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")
# Because external projects do not provide source-level tracking,
# the vulkan-shaders-gen sources need to be explicitly added to
# ensure that changes will cascade into shader re-generation.
file(GLOB _ggml_vk_shaders_gen_sources
CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.cpp"
"${_ggml_vk_input_dir}/*.h")
add_custom_command(
OUTPUT ${_ggml_vk_header}
${_ggml_vk_source}
COMMAND ${_ggml_vk_genshaders_cmd}
--glslc ${Vulkan_GLSLC_EXECUTABLE}
--input-dir ${_ggml_vk_input_dir}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
--target-cpp ${_ggml_vk_source}
--no-clean
DEPENDS ${_ggml_vk_shader_files}
${_ggml_vk_shaders_gen_sources}
vulkan-shaders-gen
COMMENT "Generate vulkan shaders"
)
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})
else()
message(WARNING "Vulkan not found")
endif()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,31 @@
cmake_minimum_required(VERSION 3.19)
project("vulkan-shaders-gen" C CXX)
find_package (Threads REQUIRED)
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat glslc support")
endif()
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat2 glslc support")
endif()
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
message(STATUS "Enabling dot glslc support")
endif()
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
message(STATUS "Enabling bfloat16 glslc support")
endif()
if (GGML_VULKAN_SHADER_DEBUG_INFO)
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
message(STATUS "Enabling shader debug info")
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)

View File

@ -0,0 +1,29 @@
#version 450
#include "types.comp"
#include "generic_binary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
}
}

View File

@ -0,0 +1,69 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#include "types.comp"
#include "generic_binary_head.comp"
const uint num_threads = 256;
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
#if ADD_RMS
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif
void main() {
uint idx = get_idx();
uint orig_idx = idx;
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
FLOAT_TYPE sum_sq = 0;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
sum_sq += sum*sum;
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
idx += num_threads;
}
#if ADD_RMS
if (p.param3 != 0) {
// reduce the sum within each subgroup, then across subgroups
const uint NumSubgroups = num_threads / gl_SubgroupSize;
sum_sq = subgroupAdd(sum_sq);
if (gl_SubgroupInvocationID == 0) {
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
sum_sq += sumsh[gl_SubgroupID + s];
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
}
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
}
}
#endif
}

View File

@ -0,0 +1,42 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.comp"
layout (push_constant) uniform parameter
{
uint ne0;
uint ne1;
uint s01;
uint s02;
uint s11;
uint s21;
} p;
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) readonly buffer Z {int32_t data_c[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i1 = gl_WorkGroupID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i11 = data_c[i1 + i2 * p.s21];
const uint s1 = p.ne0;
const uint s2 = p.ne0 * p.ne1;
const uint d0 = i1 * s1 + i2 * s2;
const uint a0 = i1 * p.s01 + i2 * p.s02;
const uint b0 = i11 * p.s11;
for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
}
}

View File

@ -0,0 +1,60 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define FLT_MAX 3.402823466e+38F
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
shared uint tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
if (row >= p.KY) {
return;
}
A_TYPE amax = -FLT_MAX;
uint acol = col;
if (col < p.KX) {
amax = data_a[row*p.KX + col];
}
for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
A_TYPE val = data_a[row*p.KX + i];
if (val > amax) {
amax = val;
acol = i;
}
}
tmp[col] = acol;
tmpmax[col] = amax;
barrier();
[[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
if (col < s && col + s < p.KX) {
if (tmpmax[col] < tmpmax[col + s]) {
tmpmax[col] = tmpmax[col + s];
tmp[col] = tmp[col + s];
}
}
barrier();
}
if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
}
}

View File

@ -0,0 +1,79 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.comp"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
#define ASC 0
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
uint order;
} p;
shared int dst_row[BLOCK_SIZE];
shared A_TYPE a_sh[BLOCK_SIZE];
void swap(uint idx0, uint idx1) {
int tmp = dst_row[idx0];
dst_row[idx0] = dst_row[idx1];
dst_row[idx1] = tmp;
}
void argsort(bool needs_bounds_check) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
// initialize indices
dst_row[col] = col;
a_sh[col] = data_a[row_offset + col];
barrier();
uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
const int ixj = int(col ^ j);
int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;
int sh_idx_0 = dst_row[idx_0];
int sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
if ((idx_0_oob ||
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
swap(idx_0, idx_1);
}
barrier();
}
}
if (col < p.ncols) {
if (p.order == ASC) {
data_d[row_offset + col] = dst_row[col];
} else {
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
}
}
}
void main() {
if (p.ncols == BLOCK_SIZE) {
argsort(false);
} else {
argsort(true);
}
}

View File

@ -0,0 +1,17 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}

View File

@ -0,0 +1,41 @@
#version 450
#include "types.comp"
#include "generic_binary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
const int dim = p.param3;
if (idx >= p.ne) {
return;
}
const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
const uint i2_offset = i2*p.ne21*p.ne20;
const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
uint o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);
#else
if (is_src0) {
data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];
} else {
data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];
}
#endif
}

View File

@ -0,0 +1,49 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
#extension GL_EXT_control_flow_attributes : require
const uint num_threads = 128;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 4;
// fast path for when all four iterations are in-bounds
if (idx + (num_iter-1)*num_threads < p.ne) {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}
} else {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}
}
}

View File

@ -0,0 +1,105 @@
#version 450
#include "types.comp"
layout (push_constant) uniform parameter
{
uint ne;
uint batches;
uint channels;
uint dst_w;
uint dst_h;
uint src_w;
uint src_h;
uint knl_w;
uint knl_h;
int stride_x;
int stride_y;
int pad_x;
int pad_y;
int dilation_x;
int dilation_y;
} p;
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
uint i0 = idx / p.dst_w;
uint dst_x = idx - i0 * p.dst_w;
uint i1 = i0 / p.dst_h;
uint dst_y = i0 - i1 * p.dst_h;
uint n = i1 / p.channels;
uint c = i1 - n * p.channels;
uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
uint knl_i = c * p.knl_h * p.knl_w;
FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
sum = fma(v, k, sum);
}
}
return sum;
}
FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
uint i0 = idx / p.channels;
uint c = idx - i0 * p.channels;
uint i1 = i0 / p.dst_w;
uint dst_x = i0 - i1 * p.dst_w;
uint n = i1 / p.dst_h;
uint dst_y = i1 - n * p.dst_h;
uint src_i = n * p.channels * p.src_h * p.src_w;
uint src_row = p.src_w * p.channels;
uint knl_row = p.knl_w * p.channels;
FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
sum = fma(v, k, sum);
}
}
return sum;
}
void main() {
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
FLOAT_TYPE result =
#ifdef WHCN
conv_2d_dw_whcn(idx);
#else
conv_2d_dw_cwhn(idx);
#endif
dst_data[idx] = D_TYPE(result);
}

View File

@ -0,0 +1,349 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif
#include "types.comp"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [W, H, Cin, N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, Cout, N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t Cout;
uint32_t Cin;
uint32_t N;
// Tensor spatial sizes: kernel, input, output
uint32_t KW;
uint32_t KH;
uint32_t W;
uint32_t H;
uint32_t OW;
uint32_t OH;
// Parameters: stride, padding, dilation - 0=y, 1=x
uint32_t s0;
uint32_t s1;
uint32_t p0;
uint32_t p1;
uint32_t d0;
uint32_t d1;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
#ifdef TRANSPOSE
uint32_t s0mp; uint32_t s0L;
uint32_t s1mp; uint32_t s1L;
#endif
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;
layout(constant_id = 6) const uint SHMEM_PAD = 4;
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
return (block_size + work_size - 1) / block_size;
}
uint32_t K = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t NPQ = p.N * p.OH * p.OW;
uint32_t n_elems_out = K * NPQ;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#ifdef COOPMAT2
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_numel = BS_K * BS_CRS;
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
const uint32_t NT_K = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=Cout
C=Cin
R,S=KH,KW
P,Q=OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a;
uint32_t Cin_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
#ifdef USE_COLLECTIVES
uint32_t cached_CRS_idx;
uint32_t cached_Cin_idx;
uint32_t cached_KH_idx;
uint32_t cached_KW_idx;
if (use_collectives == 1) {
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
} else {
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
}
#else
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif
/* Load kernel to A_block: (BS_K x BS_CRS)*/
for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
#ifdef TRANSPOSE
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
#else
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
#endif
float val = knl_data[knl_idx];
if (K_idx >= K || CRS_idx_a >= CRS) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
uint32_t CRS_idx_b;
uint32_t Cin_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
if (use_collectives == 1) {
CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
} else {
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
}
#else
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif
#ifdef TRANSPOSE
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
#else
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
#endif
uint32_t src_idx =
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
float val = src_data[src_idx];
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
|| H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
#ifdef TRANSPOSE
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
#endif
) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
#ifdef COOPMAT2
coopMatPerElementNV(matC, matC, perElemOpStore);
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = regC[T_ly][T_lx];
}
}
}
}
#endif
}

View File

@ -0,0 +1,98 @@
#version 450
#include "types.comp"
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
layout (push_constant) uniform parameter {
uint32_t Cout;
uint32_t Cin;
uint32_t K;
uint32_t L;
uint32_t KL;
uint32_t nb01;
uint32_t nb02;
uint32_t nb11;
uint32_t nb1;
int32_t s0;
} p;
uint32_t Cout_idx = gl_WorkGroupID.x;
const uint32_t bs = gl_WorkGroupSize.x;
uint32_t tid = gl_LocalInvocationID.x;
// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
uint32_t tmp_len = bs*p.s0+p.K;
shared D_TYPE tmp[4096];
uint splitWork(uint workSize){
return (bs + workSize -1) / bs;
}
void main(){
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
if(idx < tmp_len){
tmp[idx] = 0.0;
}
}
uint32_t L_blocks = splitWork(p.L);
for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
if(L_block_id > 0){
barrier();
// Shift values in tmp to the current processing window
for(int i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
if(idx >= bs*p.s0 && idx < tmp_len){
tmp[idx-bs*p.s0] = tmp[idx];
tmp[idx] = 0.0;
}else if(idx >= p.K && idx < bs*p.s0){
tmp[idx] = 0.0;
}
}
}
barrier();
// Save contributions of the block to tmp
uint32_t L_idx = L_block_id*bs + tid;
for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
D_TYPE dp = 0.0;
for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
if(L_idx < p.L){
B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
dp = fma(elemKrn, elemInp, dp);
}
}
tmp[tid*p.s0 + K_idx] += dp;
barrier();
}
// Save the computed values except the last block that can have different size
uint32_t KLb_idx = L_block_id*bs*p.s0;
if(L_block_id < L_blocks-1){
for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
uint32_t sh_idx = p.s0*tid+s0_idx;
uint32_t KL_idx = KLb_idx+sh_idx;
if(KL_idx < p.KL){
data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
}
}
}
}
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
if(KL_idx < p.KL){
data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
}
}
}

View File

@ -0,0 +1,23 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
#else
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
#endif
}

View File

@ -0,0 +1,51 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
#include "dequant_funcs.comp"
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
// 16 invocations needed for init_iq_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = get_doffset() + dst_idx(idx);
uint src_idx = src0_idx_quant(idx, QUANT_K);
const uint a_offset = 0;
const uint ib = src_idx;
const vec2 dm = get_dm(ib, a_offset);
[[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
vec4 v = dequantize4(ib, j / QUANT_R, a_offset);
v = v * dm.x + vec4(dm.y);
#if QUANT_R == 2
data_d[dst_idx + j/2 + 0] = v[0];
data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1];
data_d[dst_idx + j/2 + 1] = v[2];
data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3];
#else
data_d[dst_idx + j + 0] = v[0];
data_d[dst_idx + j + 1] = v[1];
data_d[dst_idx + j + 2] = v[2];
data_d[dst_idx + j + 3] = v[3];
#endif
}
}

View File

@ -0,0 +1,296 @@
#version 450
#include "rte.comp"
#include "types.comp"
#if defined(SET_ROWS) && QUANT_K == 1
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 512;
#else
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 32;
#endif
layout (binding = 0) readonly buffer S {float data_s[];};
#if defined(SET_ROWS)
#include "generic_binary_head.comp"
layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
#if B_SIZE == 64
#define DATA_I_SWIZZLE .x
#else
#define DATA_I_SWIZZLE
#endif
#else
#include "generic_unary_head.comp"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
#endif
#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
const uint xi0 = min(15, int(x0 + 8.5));
const uint xi1 = min(15, int(x1 + 8.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q4_1)
void quantize(uint dst_idx, uint src_idx)
{
float vmin = 1.0/0.0;
float vmax = -vmin;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
const float v = data_s[src_idx + j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(vmin);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - vmin)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
const uint xi0 = min(15, int(x0 + 0.5));
const uint xi1 = min(15, int(x1 + 0.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q5_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
const uint xi0 = min(31, int(x0 + 16.5));
const uint xi1 = min(31, int(x1 + 16.5));
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
}
data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
}
#endif
#if defined(DATA_A_Q5_1)
void quantize(uint dst_idx, uint src_idx)
{
float min = data_s[src_idx + 0];
float max = min;
[[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {
const float v = data_s[src_idx + j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = (d != 0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(min);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - min)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;
const uint xi0 = uint(x0 + 0.5);
const uint xi1 = uint(x1 + 0.5);
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);
}
data_q[dst_idx].qh = qh;
}
#endif
#if defined(DATA_A_Q8_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0; // absolute max
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {
const float v = data_s[src_idx + j];
amax = max(amax, abs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {
const float x0 = data_s[src_idx + j]*id;
data_q[dst_idx].qs[j] = int8_t(round(x0));
}
}
#endif
#if defined(DATA_A_IQ4_NL)
uint best_index(float x) {
if (x <= kvalues_iq4nl[0]) return 0;
if (x >= kvalues_iq4nl[15]) return 15;
int ml = 0, mu = 15;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;
}
return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;
}
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = (d != 0.0) ? 1.0/d : 0.0;
float sumqx = 0, sumq2 = 0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;
const uint xi0 = best_index(x0);
const uint xi1 = best_index(x1);
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j];
const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
}
#endif
#if defined(DATA_A_BF16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
}
#endif
#if defined(SET_ROWS)
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
uint i12 = fastmod(i03, p.ne12);
uint i11 = fastmod(i02, p.ne11);
uint i10 = i01;
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
quantize(dst_idx, src0_idx);
}
#else
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = dst_idx_quant(idx, QUANT_K);
uint src_idx = get_aoffset() + src0_idx(idx);
quantize(dst_idx, src_idx);
}
#endif

View File

@ -0,0 +1,17 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}

View File

@ -0,0 +1,31 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.comp"
#include "generic_head.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
const uint CHUNK_SIZE = 512;
void main() {
const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
const uint col = gl_LocalInvocationID.x;
uint count = 0;
[[unroll]]
for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
const uint idx = base + i + col;
if (idx >= p.KX) {
break;
}
count += uint(data_a[idx] == data_b[idx]);
}
atomicAdd(data_d[0], D_TYPE(count));
}

View File

@ -0,0 +1,20 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_GlobalInvocationID.x * 16;
if (i >= p.nel) {
return;
}
[[unroll]] for (uint l = 0; l < 16; l++) {
data_b[i + l] = D_TYPE(data_a[i + l]);
}
}

View File

@ -0,0 +1,616 @@
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
#include "types.comp"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
}
#endif
#if defined(DATA_A_F16)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
}
#endif
#if defined(DATA_A_BF16)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
}
#endif
#if defined(DATA_A_Q4_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2(vui & 0xF, vui >> 4) - 8.0f);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);
}
#endif
#if defined(DATA_A_Q4_1)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(vui & 0xF, vui >> 4);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12);
}
#endif
#if defined(DATA_A_Q5_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f);
}
#endif
#if defined(DATA_A_Q5_1)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = data_a[a_offset + ib].qh;
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = data_a_packed16[a_offset + ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y);
}
#endif
#if defined(DATA_A_Q8_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_IQ1_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const int i8 = int(iqs % 8);
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qs = data_a[a_offset + ib].qs[ib8];
const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3);
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
// Signed bitfield extract.
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
return dl * (vec2(gvec) + delta);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const int i8 = int(iqs % 8);
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qs = data_a[a_offset + ib].qs[ib8];
const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1;
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
// Signed bitfield extract.
const ivec4 gvec = ivec4(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2),
bitfieldExtract(grid, 2 * (i8 + 2), 2),
bitfieldExtract(grid, 2 * (i8 + 3), 2)
);
return dl * (vec4(gvec) + delta);
}
#endif
#if defined(DATA_A_IQ1_M)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib8 = iqs / 8;
const uint ib16 = iqs / 16;
const int i8 = int(iqs % 8);
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
// Signed bitfield extract.
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
return dl * (vec2(gvec) + delta);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib8 = iqs / 8;
const uint ib16 = iqs / 16;
const int i8 = int(iqs % 8);
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
// Signed bitfield extract.
const ivec4 gvec = ivec4(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2),
bitfieldExtract(grid, 2 * (i8 + 2), 2),
bitfieldExtract(grid, 2 * (i8 + 3), 2)
);
return dl * (vec4(gvec) + delta);
}
#endif
#if defined(DATA_A_IQ2_XXS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = (iqs / 8) % 4;
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
const float db = 0.25 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = (iqs / 8) % 4;
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
const float db = 0.25 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ2_XS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
const float db = 0.25 * (0.5 + scale);
const uint sign7 = qs >> 9;
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
const float db = 0.25 * (0.5 + scale);
const uint sign7 = qs >> 9;
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ2_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
const float db = 0.25 * (0.5 + scale);
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid[iqs % 4] * (sign0 ? -1.0 : 1.0),
grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
const float db = 0.25 * (0.5 + scale);
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ3_XXS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint is = QUANT_K / 4 + 4 * ib32;
const uint qs = data_a[a_offset + ib].qs[ib4];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
const float db = 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint is = QUANT_K / 4 + 4 * ib32;
const uint qs = data_a[a_offset + ib].qs[ib4];
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
const float db = 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq3xxs_grid[qs]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ3_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint qs = data_a[a_offset + ib].qs[iqs / 4];
const uint qh = data_a[a_offset + ib].qh[iqs / 32];
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
const uint scale = data_a[a_offset + ib].scales[iqs / 64];
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));
return db * vec2(
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint qs = data_a[a_offset + ib].qs[ib4];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
const uint scale = data_a[a_offset + ib].scales[ib32 / 2];
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));
return db * vec4(
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),
int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),
int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ4_XS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint iq = 16 * ib32 + (iqs % 16);
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);
qs = (qs >> qshift) & uint8_t(0xF);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint iq = 16 * ib32 + (iqs % 16);
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec4 qs = u8vec4(
data_a[a_offset + ib].qs[iq + 0],
data_a[a_offset + ib].qs[iq + 1],
data_a[a_offset + ib].qs[iq + 2],
data_a[a_offset + ib].qs[iq + 3]
);
qs = (qs >> qshift) & uint8_t(0xF);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec4(
kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],
kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
}
#endif
#if defined(DATA_A_IQ4_NL)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
vec2 v1 = dequantize(ib, iqs + 1, a_offset);
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
}
#endif
#if defined(DATA_A_IQ1_M)
vec2 get_dm(uint ib, uint a_offset) {
const uint16_t[4] scales = data_a[a_offset + ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
return vec2(d, 0);
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
}
#endif
#if defined(DATA_A_Q2_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
const vec2 d = vec2(data_a[a_offset + ib].d);
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q3_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 64; // 0,1
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
const uint hmi = (iqs % 16) * 2; // 0,2,4..30
const uint j = (iqs % 64) / 4; // 0..3
const uint is = iqs / 8; // 0..15
const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
| (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[a_offset + ib].d) * float(us - 32);
return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q4_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF), m),
fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q5_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const uint qhi = (iqs % 16) * 2; // 0,2,4..30
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q6_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 64; // 0,1
const uint b = (iqs % 64) / 32; // 0,1
const uint is_b = (iqs % 16) / 8; // 0,1
const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uint is = 8 * n + qhshift + is_b; // 0..15
const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]);
return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif

View File

@ -0,0 +1,720 @@
#include "types.comp"
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
};
float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
qs >>= shift;
qs &= 0x0F0F;
qs = unpack8(qs)[idx & 1];
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
block_q4_1 block;
};
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(qs) * d + m;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
block_q5_0 block;
};
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
block_q5_1 block;
};
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint uint_qh = bl.block.qh;
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(qs | qh) * d + m;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
block_q8_0_packed16 block;
};
float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx;
// Load 16b and select the byte for this element
int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
float16_t ret = float16_t(qs) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
block_q2_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
block_q2_K_packed16 block;
};
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qs = (qs >> qsshift) & 0x0303;
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
block_q3_K block;
};
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint iqs = idx;
const uint n = iqs / 128; // 0,1
const uint qsi = n * 32 + (iqs % 32); // 0..63
const uint hmi = (iqs % 32); // 0..31
const uint j = (iqs % 128) / 8; // 0..15
const uint is = iqs / 16; // 0..15
const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
uint32_t scaleidx0 = (is < 8) ? is : (is-8);
uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
uint32_t scaleidx1 = is + 8 - (is/4)*4;
uint32_t scaleidx1shift = (is/4)*2;
const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
const float16_t dl = bl.block.d * float16_t(us - 32);
float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
block_q4_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
block_q4_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
block_q4_K_packed128 block;
};
#if defined(IS_MUL_MM2)
// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
// into shared memory and then process the whole tile using those scales.
// There is a fetch function that loads into private variables and then a store
// function that stores into shared memory.
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
// the part that fetches from the structure (which has a different block layout).
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
const uint shAscales_stride = (BM + 2);
// 1 scale per 32 elements -> 8 scales per block, per row
shared vec2 shAscales[8 * shAscales_stride];
uvec4 row_v;
#endif
#if defined(DATA_A_Q4_K)
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q4_k_packed128[block_index].q4k[0];
}
}
#endif
#if defined(DATA_A_Q5_K)
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q5_k_packed128[block_index].q5k[0];
}
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
void store_scalesQ4_K(uint tid)
{
barrier();
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
[[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
uint is = idx + is_start;
uvec4 v = row_v;
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
}
barrier();
}
#endif
#endif
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q4k[0];
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
#endif
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
float ret = d * float(qs) - m;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
block_q5_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
block_q5_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
block_q5_K_packed128 block;
};
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
#endif
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = ((qh >> is) & 0x101) << 4;
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs | qh)[idx & 1];
float ret = d * float(qs) - m;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
block_q6_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
block_q6_K_packed16 block;
};
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x40) >> 6; // 0,1
const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
const uint is = (idx & 0xF0) >> 4; // 0..15
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
ql = (ql >> (b * 4)) & 0x0F0F;
uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qh = ((qh >> qhshift) & 0x0303) << 4;
int q = unpack8(ql | qh)[idx & 1];
float16_t ret = dscale * float16_t(q - 32);
return ret;
}
#if defined(DATA_A_IQ1_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
block_iq1_s block;
};
float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5;
const uint ib8 = (idx & 0xF8) >> 3;
const uint qh = bl.block.qh[ib32];
const uint qs = bl.block.qs[ib8];
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
return ret;
}
#endif
#if defined(DATA_A_IQ1_M)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
block_iq1_m block;
};
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
block_iq1_m_packed64 block;
};
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
const uint idx = coordInBlock[1];
uvec2 scales = unpack32(bl64.block.scales);
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
const uint ib8 = (idx & 0xF8) >> 3;
const uint ib16 = (idx & 0xF0) >> 4;
const int i8 = int(idx % 8);
const uint sc = bl.block.scales[ib8 / 8];
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
return ret;
}
#endif
#if defined(DATA_A_IQ2_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
block_iq2_xxs block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
block_iq2_xxs_packed16 block;
};
float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint ib8 = (idx & 0x18) >> 3; // 0..3
const uint iqs = 8 * ib32 + ib8;
const uint qs = bl.block.qs[iqs];
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
sign |= bitCount(sign) << 7;
uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 g = vec2(unpack8(g2));
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
#endif
#if defined(DATA_A_IQ2_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
block_iq2_xs block;
};
float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint is = (idx & 0xE0) >> 5; // 0..8
const uint sshift = (idx & 0x10) >> 2; // 0,4
const uint iqs = (idx & 0xF8) >> 3; // 0..63
const uint16_t qs = bl.block.qs[iqs];
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
uint sign = uint(qs >> 9);
sign |= bitCount(sign) << 7;
uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 g = vec2(unpack8(g2));
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
#endif
#if defined(DATA_A_IQ2_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
block_iq2_s block;
};
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint ib8 = (idx & 0xF8) >> 3; // 0..31
const uint qhshift = 2 * (ib8 % 4);
const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib32];
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);
const float d = float(bl.block.d);
const float db = d * 0.25 * (0.5 + scale);
const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));
uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ3_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
block_iq3_xxs block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
block_iq3_xxs_packed16 block;
};
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
uint idx = coordInBlock[1];
const uint iqs = (idx & 0xFC) >> 2; // 0..63
const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint signs = pack32(u16vec2(
bl16.block.qs[is/2+0],
bl16.block.qs[is/2+1]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ3_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
block_iq3_s block;
};
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
const uint iqs = (idx & 0xFC) >> 2; // 0..63
const uint iqh = (idx & 0xE0) >> 5;
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint qh = bl.block.qh[iqh];
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));
const uint scale = bl.block.scales[iqs / 16];
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ4_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
block_iq4_xs block;
};
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 16) >> 2;
const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
return ret;
}
#endif
#if defined(DATA_A_IQ4_NL)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
block_iq4_nl block;
};
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
return ret;
}
#endif
#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
};
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float d = e8m0_to_fp32(bl.block.e);
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
return ret;
}
#endif
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncQ4_1
#elif defined(DATA_A_Q5_0)
#define dequantFuncA dequantFuncQ5_0
#elif defined(DATA_A_Q5_1)
#define dequantFuncA dequantFuncQ5_1
#elif defined(DATA_A_Q8_0)
#define dequantFuncA dequantFuncQ8_0
#elif defined(DATA_A_Q2_K)
#define dequantFuncA dequantFuncQ2_K
#elif defined(DATA_A_Q3_K)
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#define fetch_scales fetch_scalesQ4_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#define fetch_scales fetch_scalesQ5_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ1_S)
#define dequantFuncA dequantFuncIQ1_S
#elif defined(DATA_A_IQ1_M)
#define dequantFuncA dequantFuncIQ1_M
#elif defined(DATA_A_IQ2_XXS)
#define dequantFuncA dequantFuncIQ2_XXS
#elif defined(DATA_A_IQ2_XS)
#define dequantFuncA dequantFuncIQ2_XS
#elif defined(DATA_A_IQ2_S)
#define dequantFuncA dequantFuncIQ2_S
#elif defined(DATA_A_IQ3_XXS)
#define dequantFuncA dequantFuncIQ3_XXS
#elif defined(DATA_A_IQ3_S)
#define dequantFuncA dequantFuncIQ3_S
#elif defined(DATA_A_IQ4_XS)
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#endif

View File

@ -0,0 +1,13 @@
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint M;
uint K;
uint stride_a;
uint stride_b;
uint nel;
} p;
#include "types.comp"

View File

@ -0,0 +1,42 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq1_m data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint ib64 = ib32 / 2;
const uint b_idx = 256 * ib + 32 * ib32;
const uint16_t[4] scales = data_a[ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint sc = data_a[ib].scales[ib64];
[[unroll]] for (int l = 0; l < 4; ++l) {
const uint ib16 = 2 * ib32 + l / 2;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));
const uint qs = data_a[ib].qs[4 * ib32 + l];
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
[[unroll]] for (int j = 0; j < 8; ++j) {
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
}
}
}

View File

@ -0,0 +1,35 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq1_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
uint qh = data_a[ib].qh[ib32];
const float d = float(data_a[ib].d);
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qs = data_a[ib].qs[4 * ib32 + l];
const uint hi = bitfieldExtract(qh, 3 * int(l), 3);
const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]);
[[unroll]] for (int j = 0; j < 8; ++j) {
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
}
}
}

View File

@ -0,0 +1,44 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
const float d = float(data_a[ib].d);
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
const vec2 db = d * (0.5 + scale) * 0.25;
uint qh = data_a[ib].qh[ib32];
[[unroll]] for (uint l = 0; l < 4; ++l) {
uint qs = data_a[ib].qs[4 * ib32 + l];
const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
qs |= (qh << (8 - 2 * l)) & 0x300;
const uvec2 grid = iq2s_grid[qs];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@ -0,0 +1,43 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
const float d = float(data_a[ib].d);
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
const vec2 db = d * (0.5 + scale) * 0.25;
[[unroll]] for (uint l = 0; l < 4; ++l) {
uint16_t qs = data_a[ib].qs[4 * ib32 + l];
const uint sign7 = qs >> 9;
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uvec2 grid = iq2xs_grid[qs & 511];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@ -0,0 +1,49 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale block (32 values)
// Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const float d = float(data_a[ib].d);
uint signscale = pack32(u8vec4(
data_a[ib].qs[8*is + 4],
data_a[ib].qs[8*is + 5],
data_a[ib].qs[8*is + 6],
data_a[ib].qs[8*is + 7]
));
const float db = d * (0.5 + (signscale >> 28)) * 0.25;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uint qs = data_a[ib].qs[8 * is + l];
const uvec2 grid = iq2xxs_grid[qs];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@ -0,0 +1,40 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq3_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale nibble.
// Each block contains 4 scale bytes (8 scales) for 256 output values.
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const float d = float(data_a[ib].d);
const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf));
// We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
uint qh = data_a[ib].qh[is];
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = 8 * is + l;
const uint qs = data_a[ib].qs[iqs];
const uint gidx = qs | ((qh << (8 - l)) & 256);
const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1));
const u8vec4 grid = unpack8(iq3s_grid[gidx]);
data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));
}
}

View File

@ -0,0 +1,51 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale block (32 values)
// 8 threads handle 1 superblock
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const uint s_idx = QUANT_K / 4 + 4 * is;
const float d = float(data_a[ib].d);
uint signscale = pack32(u8vec4(
data_a[ib].qs[s_idx + 0],
data_a[ib].qs[s_idx + 1],
data_a[ib].qs[s_idx + 2],
data_a[ib].qs[s_idx + 3]
));
const float db = d * (0.5 + (signscale >> 28)) * 0.5;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
// Restore parity bit.
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint qs0 = data_a[ib].qs[8 * is + 2 * l];
const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1];
const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]);
const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@ -0,0 +1,32 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = float(data_a[ib].d);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@ -0,0 +1,34 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (1 scale and 32 quantized values)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const float d = float(data_a[ib].d);
// Scales are 6 bits
const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)
| (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);
const float dl = d * (int(scale) - 32);
const uint b_idx = 256 * ib + 32 * ib32;
const uint q_idx = 16 * ib32;
[[unroll]] for (uint l = 0; l < 16; ++l) {
data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@ -0,0 +1,32 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@ -0,0 +1,34 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint ip = tid / 32;
const uint il = tid - 32 * ip;
const uint is = 8 * ip + il / 16;
const uint y_idx = i * QUANT_K + 128 * ip + il;
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
}
}

View File

@ -0,0 +1,42 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.nel / QUANT_K) {
return;
}
const uint r = gl_LocalInvocationID.x / 4;
const uint tid = r / 2;
const uint is0 = r % 2;
const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
const uint n = tid / 4;
const uint j = tid - 4*n;
const uint8_t m = uint8_t(1 << (4*n + j));
const uint is = 8*n + 2*j + is0;
const uint shift = 2*j;
const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
(data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
const uint qs_idx = 32*n;
for (uint l = l0; l < l0 + 4; ++l) {
data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
}
}
}

View File

@ -0,0 +1,30 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = float(data_a[ib].d);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f));
data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f));
}
}

View File

@ -0,0 +1,32 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
}
}

View File

@ -0,0 +1,68 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint ib = gl_WorkGroupID.x * 256 + wgy;
if (ib >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint il = tid / 8;
const uint ir = tid % 8;
const uint is = 2 * il;
const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;
uint scidx0 = (is < 4) ? is : (is + 4);
uint scidx1 = (is < 4) ? is : (is - 4);
uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint scidxshift1 = (is < 4) ? 0 : 2;
uint mbidx0 = is + 4;
uint mbidx1 = (is < 4) ? is + 4 : is;
uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint mbidxshift0 = (is < 4) ? 0 : 4;
uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint mbidxshift1 = (is < 4) ? 0 : 2;
uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d1 = dall * sc;
const FLOAT_TYPE m1 = dmin * mbyte;
scidx0 = (is < 4) ? is + 1 : (is + 5);
scidx1 = (is < 4) ? is + 1 : (is - 3);
scidxmask1 = (is < 4) ? 0x30 : 0xC0;
scidxshift1 = (is < 4) ? 0 : 2;
mbidx0 = is + 5;
mbidx1 = (is < 4) ? is + 5 : is + 1;
mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbidxshift0 = (is < 4) ? 0 : 4;
mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
mbidxshift1 = (is < 4) ? 0 : 2;
sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * mbyte;
[[unroll]] for (uint l = 0; l < n; ++l) {
data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1);
data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2);
}
}
}

View File

@ -0,0 +1,34 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = q_idx + l;
const uint vui = uint(data_a[ib].qs[iqs]);
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
}
}

View File

@ -0,0 +1,35 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint qh = data_a[ib].qh;
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = q_idx + l;
const uint vui = uint(data_a[ib].qs[iqs]);
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
}
}

View File

@ -0,0 +1,70 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint ib = gl_WorkGroupID.x * 256 + wgy;
if (ib >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint il = tid / 16;
const uint ir = tid % 16;
const uint is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;
const uint qh_idx = 2 * ir;
uint scidx0 = (is < 4) ? is : (is + 4);
uint scidx1 = (is < 4) ? is : (is - 4);
uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint scidxshift1 = (is < 4) ? 0 : 2;
uint mbidx0 = is + 4;
uint mbidx1 = (is < 4) ? is + 4 : is;
uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint mbidxshift0 = (is < 4) ? 0 : 4;
uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint mbidxshift1 = (is < 4) ? 0 : 2;
uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d1 = dall * sc;
const FLOAT_TYPE m1 = dmin * mbyte;
scidx0 = (is < 4) ? is + 1 : (is + 5);
scidx1 = (is < 4) ? is + 1 : (is - 3);
scidxmask1 = (is < 4) ? 0x30 : 0xC0;
scidxshift1 = (is < 4) ? 0 : 2;
mbidx0 = is + 5;
mbidx1 = (is < 4) ? is + 5 : is + 1;
mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbidxshift0 = (is < 4) ? 0 : 4;
mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
mbidxshift1 = (is < 4) ? 0 : 2;
sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * mbyte;
const uint8_t hm1 = uint8_t(1 << (2 * il ));
const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
}
}

View File

@ -0,0 +1,33 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint ip = tid / 32;
const uint il = tid - 32 * ip;
const uint is = 8 * ip + il / 16;
const uint y_idx = i * QUANT_K + 128 * ip + il;
const uint ql_idx = 64 * ip + il;
const uint8_t qh = data_a[i].qh[32 * ip + il];
const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
}
}

View File

@ -0,0 +1,31 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 16*il;
const float d = float(data_a[ib].d);
const uint q_idx = 16*il;
[[unroll]] for (uint l = 0; l < 16; l += 2) {
data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
}
}

View File

@ -0,0 +1,34 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : enable
layout (push_constant) uniform parameter
{
uint ncols;
uint rows_per_channel;
uint n_past;
} p;
#include "types.comp"
layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint col = gl_GlobalInvocationID.y;
const uint row = gl_GlobalInvocationID.x;
if (col >= p.ncols) {
return;
}
const uint i = row*p.ncols + col;
if (col > p.n_past + row % p.rows_per_channel) {
data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));
} else {
data_d[i] = D_TYPE(data_a[i]);
}
}

View File

@ -0,0 +1,27 @@
#version 450
#include "types.comp"
#include "generic_binary_head.comp"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
}

View File

@ -0,0 +1,21 @@
#version 450
#include "rte.comp"
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(exp(float(data_a[i])));
}

View File

@ -0,0 +1,382 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.comp"
#include "flash_attn_base.comp"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br];
shared vec4 Qf[Br][HSK / 4];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
}
}
barrier();
vec4 Of[Br][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
}
}
float Lf[Br], Mf[Br];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
float slope[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = 1.0;
}
// ALiBi
if (p.max_bias > 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = 0.0;
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
}
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
}
}
}
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
}
}
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if (!KV_bounds_check || j * Bc + c < KV) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
}
}
}
barrier();
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
Sf[r][c] += slope[r]*mvf;
}
}
barrier();
}
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
}
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
}
eMf[r] = exp(Moldf[r] - Mf[r]);
// Compute sum across row of P
rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowsumf[r] += Pf[r][c];
}
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] += Pf[r][c] * Vf;
}
}
}
barrier();
}
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float rowmaxf, eMf;
tmpsh[tid] = Mf[r];
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
}
barrier();
}
rowmaxf = tmpsh[d_tid];
barrier();
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf = exp(Moldf - Mf[r]);
Lf[r] = eMf*Lf[r];
tmpsh[tid] = Lf[r];
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
}
barrier();
}
Lf[r] = tmpsh[d_tid];
barrier();
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
Of[r][d] += tmpshv4[tid + s];
tmpshv4[tid] = Of[r][d];
}
barrier();
}
Of[r][d] = tmpshv4[d_tid];
barrier();
}
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ms;
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
#endif
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}

View File

@ -0,0 +1,202 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15;
const bool KV_bounds_check = Clamp != 0;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask_n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
#define SINK_ENABLE_BIT (1<<24)
#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
// Load the sink value, indexed by Q's dimension 2.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
return ACC_TYPE(data_s[h]);
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
void init_indices()
{
N = p.N;
KV = p.KV;
i = gl_WorkGroupID.x;
split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
Tr = CEIL_DIV(N, Br);
start_j = split_k_index * p.split_kv / Bc;
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
iq3 = gl_WorkGroupID.z;
// broadcast factors
rk2 = p.neq2/p.nek2;
rk3 = p.neq3/p.nek3;
rv2 = p.neq2/p.nev2;
rv3 = p.neq3/p.nev3;
// k indices
ik3 = iq3 / rk3;
ik2 = iq2 / rk2;
// v indices
iv3 = iq3 / rv3;
iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
k_stride = p.nb11;
v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}

View File

@ -0,0 +1,416 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#include "types.comp"
#include "flash_attn_base.comp"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
#define tile_row(r) (row_tid * rows_per_thread + (r))
// Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
if ((HSK % 16) != 0) {
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
if (i + tid < Br * qstride) {
Qf[i + tid] = f16vec4(0);
}
}
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kshstride) {
ksh[i + tid] = f16vec4(0);
}
}
barrier();
}
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0);
}
}
float Lf[rows_per_thread], Mf[rows_per_thread];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
// ALiBi
if (p.max_bias > 0.0f) {
if (tid < Br) {
uint r = tid;
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
barrier();
} else {
if (tid < Br) {
uint r = tid;
slope[r] = 1.0;
}
barrier();
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
ksh[c * kshstride + d] = K_Tf;
}
}
barrier();
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
uint coord = gl_SubgroupID * MatBc * sfshstride;
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / Br;
uint32_t r = (idx + tid) % Br;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
}
}
barrier();
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if (!KV_bounds_check || j * Bc + c < KV) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
}
barrier();
}
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
}
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
float Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
}
}
}
barrier();
}
// reduce across threads
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE M = Mf[r];
tmpsh[tid] = M;
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
M = max(M, tmpsh[tid ^ s]);
barrier();
tmpsh[tid] = M;
barrier();
}
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
eMf[r] = exp(Moldf[r] - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE L = Lf[r];
tmpsh[tid] = L;
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
L += tmpsh[tid ^ s];
barrier();
tmpsh[tid] = L;
barrier();
}
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
Of[r][d] += tmpshv4[tid ^ s];
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
}
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ACC_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}

View File

@ -0,0 +1,304 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable
#include "types.comp"
#include "dequant_funcs_cm2.comp"
#include "flash_attn_base.comp"
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return max(x, y);
}
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return x;
}
// Replace matrix elements >= numRows or numCols with 'replace'
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
if (row >= numRows || col >= numCols) {
return replace;
}
return elem;
}
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
return exp(elem);
}
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
return max(elem0, elem1);
}
#if defined(BLOCK_SIZE)
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c < HSV) {
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if defined(BLOCK_SIZE)
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
// hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if !defined(BLOCK_SIZE)
k_stride &= ~7;
v_stride &= ~7;
#endif
m_stride &= ~7;
}
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
// ALiBi
if (p.max_bias > 0.0f) {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) {
[[unroll]]
for (int k = 0; k < S.length(); ++k) {
S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
}
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
// Clear padding elements to -inf, so they don't contribute to rowmax
if (Clamp != 0 &&
((j + 1) * Bc > KV ||
(i + 1) * Br > N)) {
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
coopMatPerElementNV(M, rowmax, Max, Mold);
coopMatPerElementNV(P, S - M, Exp);
coopMatPerElementNV(eM, Mold - M, Exp);
// Clear padding elements to 0, so they don't contribute to rowsum
if (Clamp != 0 &&
((j + 1) * Bc > KV ||
(i + 1) * Br > N)) {
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
}
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
// compute rowsum by multiplying by matrix of all ones.
coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum);
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
L = eM*L + rowsum;
// This is the "diagonal" matrix in the paper, but since we do componentwise
// multiply rather than matrix multiply it has the diagonal element smeared
// across the row
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
// resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
// multiply with fp16 accumulation, then add to O.
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV);
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
// resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
// resize M by using smear/reduce
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
// O, Ldiag, Mr all have the same type so all element locations match
[[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
ACC_TYPE sink = S[i];
ACC_TYPE ms = ACC_TYPE(1.0f);
ACC_TYPE vs = ACC_TYPE(1.0f);
if (sink > Mr[i]) {
ms = exp(Mr[i] - sink);
O[i] *= ms;
} else {
vs = exp(sink - Mr[i]);
}
Ldiag[i] = Ldiag[i]*ms + vs;
}
}
[[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
}
O = Ldiag*O;
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
// permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
}
}

View File

@ -0,0 +1,120 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) readonly buffer B {float data_s[];};
layout (binding = 2) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne3;
uint k_num;
uint sinks;
} p;
shared float tmpsh[BLOCK_SIZE];
void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint iq3 = gl_WorkGroupID.z;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2;
// Compute the max m value for the row
float m_max = -1.0/0.0;
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float m = data_a[m_offset + (k + tid) * lm_stride];
m_max = max(m_max, m);
}
// reduce across the workgroup
tmpsh[tid] = m_max;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
m_max = max(m_max, tmpsh[tid + s]);
tmpsh[tid] = m_max;
}
barrier();
}
m_max = tmpsh[0];
barrier();
// Compute L based on m_max
float L = 0;
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float l = data_a[l_offset + (k + tid) * lm_stride];
float m = data_a[m_offset + (k + tid) * lm_stride];
L += exp(m - m_max) * l;
}
// reduce across the workgroup
tmpsh[tid] = L;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
L += tmpsh[tid + s];
tmpsh[tid] = L;
}
barrier();
}
L = tmpsh[0];
float sink;
if (p.sinks != 0) {
sink = data_s[n];
float ms = 1.0f;
float vs = 1.0f;
if (sink > m_max) {
ms = exp(m_max - sink);
} else {
vs = exp(sink - m_max);
}
L = L*ms + vs;
}
L = 1.0 / L;
// D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
// Scale and sum the O contributions based on m_max and store the result to memory
if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
if (p.sinks != 0) {
if (sink > m_max) {
float ms = 1.0f;
ms = exp(m_max - sink);
O *= ms;
}
}
O *= L;
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
O = clamp(O, -FLT_MAX, FLT_MAX);
data_d[iq3 * D * N + D * n + d] = O;
}
}

View File

@ -0,0 +1,13 @@
#version 450
#include "glu_head.comp"
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
float op(float a, float b) {
const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
}
#include "glu_main.comp"

View File

@ -0,0 +1,27 @@
#version 450
#include "glu_head.comp"
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
float op(float a, float b) {
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
return 0.5f * a * (1.0f + erf_approx) * b;
}
#include "glu_main.comp"

View File

@ -0,0 +1,11 @@
#version 450
#include "glu_head.comp"
const float GELU_QUICK_COEF = -1.702f;
float op(float a, float b) {
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
}
#include "glu_main.comp"

View File

@ -0,0 +1,25 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float xi = float(data_a[i]);
const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
}

View File

@ -0,0 +1,39 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float a = float(data_a[i]);
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
}

View File

@ -0,0 +1,23 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const float GELU_QUICK_COEF = -1.702f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
}

View File

@ -0,0 +1,51 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.comp"
#include "utils.comp"
layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint misalign_offsets;
float param1; float param2; int param3;
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
layout(constant_id = 0) const bool norepeat = false;
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
uint get_doffset() { return p.misalign_offsets & 0xFF; }
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
}
uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
}
uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
if (norepeat) {
return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
} else {
return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
}
}
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
}

View File

@ -0,0 +1,9 @@
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint KX;
uint KY;
float param1;
float param2;
} p;

View File

@ -0,0 +1,76 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint misalign_offsets;
float param1; float param2;
uint ne0_012mp; uint ne0_012L;
uint ne0_01mp; uint ne0_01L;
uint ne0_0mp; uint ne0_0L;
uint ne1_012mp; uint ne1_012L;
uint ne1_01mp; uint ne1_01L;
uint ne1_0mp; uint ne1_0L;
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
uint src0_idx(uint idx) {
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
}
uint dst_idx(uint idx) {
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
}
uint src0_idx_quant(uint idx, uint qk) {
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00;
}
uint dst_idx_quant(uint idx, uint qk) {
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10;
}

View File

@ -0,0 +1,42 @@
#version 450
#include "types.comp"
#include "generic_binary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint i00 = gl_GlobalInvocationID.x;
if (i00 >= p.ne00) {
return;
}
uint gid_z = gl_GlobalInvocationID.z;
while (gid_z < p.ne11 * p.ne12) {
uint gid_y = gl_GlobalInvocationID.y;
while (gid_y < p.ne10) {
const uint i10 = gid_y;
const uint i11 = gid_z / p.ne12;
const uint i12 = gid_z % p.ne12;
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#if defined(DATA_A_BF16)
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(v);
#else
data_d[d_offset + i00] = D_TYPE(v);
#endif
gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
}
}

View File

@ -0,0 +1,51 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.comp"
#include "generic_binary_head.comp"
#include "dequant_funcs.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2;
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
if (i00 >= p.ne00) {
return;
}
uint gid_z = gl_GlobalInvocationID.z;
while (gid_z < p.ne11 * p.ne12) {
uint gid_y = gl_GlobalInvocationID.y;
while (gid_y < p.ne10) {
const uint i10 = gid_y;
const uint i11 = gid_z / p.ne12;
const uint i12 = gid_z % p.ne12;
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
const uint ib = a_offset + i00/QUANT_K; // block index
const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
const uint iybs = i00 - i00%QUANT_K; // dst block start index
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
vec2 v = dequantize(ib, iqs, 0);
const vec2 dm = get_dm(ib, 0);
v = v * dm.x + dm.y;
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
}
}

View File

@ -0,0 +1,19 @@
#extension GL_EXT_shader_16bit_storage : require
#include "rte.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter
{
uint N;
uint ne00;
uint ne20;
uint mode;
float alpha;
float limit;
} p;

View File

@ -0,0 +1,29 @@
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.N) {
return;
}
const uint row = i / p.ne20;
const uint col = i - row * p.ne20;
if (p.mode == 0) {
// Default
const uint offset = p.ne00 / 2;
const uint idx = row * p.ne00 + col;
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
} else if (p.mode == 1) {
// Swapped
const uint offset = p.ne00 / 2;
const uint idx = row * p.ne00 + col;
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
} else {
// Split
const uint idx = row * p.ne00 + col;
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
}
}

View File

@ -0,0 +1,66 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared float tmp[BLOCK_SIZE];
void main() {
const uint group_size = p.KX;
const float eps = p.param1;
const uint tid = gl_LocalInvocationID.x;
const uint start = gl_WorkGroupID.x * group_size + tid;
const uint end = (gl_WorkGroupID.x + 1) * group_size;
tmp[tid] = 0.0f;
// Calculate mean
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
tmp[tid] += float(data_a[col]);
}
// tmp up partial tmps and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
const float mean = tmp[0] / group_size;
barrier();
tmp[tid] = 0.0f;
// Calculate variance
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
const float xi = float(data_a[col]) - mean;
data_d[col] = D_TYPE(xi);
tmp[tid] += xi * xi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
const float variance = tmp[0] / group_size;
const float scale = inversesqrt(variance + eps);
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
data_d[col] *= D_TYPE(scale);
}
}

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}

View File

@ -0,0 +1,104 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.comp"
#include "types.comp"
layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint batch_offset; uint offset_delta;
uint IC;
uint IW; uint IH;
uint OW; uint OH;
uint KW; uint KH;
uint pelements;
uint CHW;
int s0; int s1;
int p0; int p1;
int d0; int d1;
} p;
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
const uint NUM_ITER = 512 / BLOCK_SIZE;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * p.KH;
const uint base_linear_idx = gidx * NUM_ITER;
uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;
A_TYPE values[NUM_ITER];
BDA_OFFSET_T offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
}
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == p.KH) {
current_ky = 0;
current_kx++;
}
}
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
dst_addr.d = D_TYPE(values[idx]);
#else
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
#endif
}
}

View File

@ -0,0 +1,126 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "rte.comp"
#include "types.comp"
layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t s0;
uint32_t s1;
uint32_t s2;
uint32_t p0;
uint32_t p1;
uint32_t p2;
uint32_t d0;
uint32_t d1;
uint32_t d2;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t IC;
uint32_t KW;
uint32_t OH;
uint32_t KD_KH_KW;
uint32_t KH_KW;
uint32_t IC_KD_KH_KW;
uint32_t N_OD_OH;
uint32_t OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW;
uint32_t misalign_offsets;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
const uint32_t i = gl_GlobalInvocationID.x;
uint32_t nb10 = p.nb10;
uint32_t nb11 = p.nb11;
uint32_t nb12 = p.nb12;
uint32_t nb13 = p.nb13;
uint32_t s0 = p.s0;
uint32_t s1 = p.s1;
uint32_t s2 = p.s2;
uint32_t p0 = p.p0;
uint32_t p1 = p.p1;
uint32_t p2 = p.p2;
uint32_t d0 = p.d0;
uint32_t d1 = p.d1;
uint32_t d2 = p.d2;
uint32_t IW = p.IW;
uint32_t IH = p.IH;
uint32_t ID = p.ID;
uint32_t IC = p.IC;
uint32_t KW = p.KW;
uint32_t OH = p.OH;
uint32_t KD_KH_KW = p.KD_KH_KW;
uint32_t KH_KW = p.KH_KW;
uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW;
uint32_t N_OD_OH = p.N_OD_OH;
uint32_t OD_OH = p.OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW;
if (i >= IC_KD_KH_KW) {
return;
}
const uint32_t iic = i / KD_KH_KW;
const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const uint32_t ikw = i % KW;
const uint32_t iow = gl_GlobalInvocationID.y;
for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) {
const uint32_t in_ = iz / OD_OH;
const uint32_t iod = (iz - in_*OD_OH) / OH;
const uint32_t ioh = iz % OH;
const uint32_t iiw = iow * s0 + ikw * d0 - p0;
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
const uint32_t iid = iod * s2 + ikd * d2 - p2;
const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
if (iih >= IH || iiw >= IW || iid >= ID) {
dst_addr.d = D_TYPE(0.0f);
} else {
dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#else
if (iih >= IH || iiw >= IW || iid >= ID) {
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
} else {
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#endif
}
}

View File

@ -0,0 +1,41 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
sum[tid] += xi * xi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum[tid] += sum[tid + s];
}
barrier();
}
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
}
}

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float val = float(data_a[i]);
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
}

View File

@ -0,0 +1,27 @@
#version 450
#include "types.comp"
#include "generic_binary_head.comp"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
}

View File

@ -0,0 +1,48 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
layout (push_constant) uniform parameter {
uint ne;
uint k_num;
} p;
void main() {
// Each invocation handles four consecutive components
const uint idx = gl_GlobalInvocationID.x * 4;
if (idx >= p.ne) {
return;
}
// Check if all four components are in bounds and aligned,
// then use vector loads
if (idx + 3 < p.ne && (p.ne % 4) == 0) {
vec4 result = vec4(0.0f);
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
result += data_a4[(i * p.ne + idx) / 4];
}
data_d4[idx / 4] = result;
} else {
[[unroll]] for (uint j = 0; j < 4; ++j) {
if (idx + j < p.ne) {
float result = 0.0f;
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
result += data_a[i * p.ne + idx + j];
}
data_d[idx + j] = result;
}
}
}
}

View File

@ -0,0 +1,169 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
#endif
uint a_offset, b_offset, d_offset, y_offset;
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index
#if K_PER_ITER == 8
#if QUANT_R == 2
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else
const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
#endif
#else
// Check if the second of the pair of elements is OOB, and don't fetch B or
// accumulate it. We still fetch a pair of elements for A, which is fine for
// quantized formats since they'll be within the same block. We should
// probably skip fetching the second element for F16/F32, but as of now we
// still do.
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
FLOAT_TYPE b0 = 0, b1 = 0;
b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
if (!OOB) {
b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
}
#endif
uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib = (ibi + col)/QUANT_K; // block index
ibi += p.ncols;
#if K_PER_ITER == 8
vec4 v = dequantize4(ib, iqs, a_offset);
vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
const vec2 dm = get_dm(ib, a_offset);
if (dm.y != 0) { // quant has min component
v = v * dm.x + dm.y;
v2 = v2 * dm.x + dm.y;
}
// matrix multiplication
FLOAT_TYPE rowtmp = dot(bv0, v);
rowtmp += dot(bv1, v2);
if (dm.y == 0)
rowtmp *= dm.x;
temp[j][n] += rowtmp;
#else
const vec2 v = dequantize(ib, iqs, a_offset);
// matrix multiplication
temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
if (!OOB) {
temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
}
#endif
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
num_iters++;
}
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
// If the K dimension is odd, we need lastiter==true on the last iteration
// so OOB is computed correctly. Skip some unrolling to make that happen.
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
while (i < num_iters) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
i++;
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,182 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#endif
#ifdef MUL_MAT_ID
#define EXPERT_COUNT 8
#endif
#include "types.comp"
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif
#include "dequant_funcs.comp"
layout (push_constant) uniform parameter
{
uint ncols;
uint stride_a;
uint stride_b;
uint stride_d;
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
#ifdef MUL_MAT_ID
uint nei0;
uint ne11;
#else
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
#endif
} p;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
#endif
#ifndef MUL_MAT_ID
uint batch_idx_a = 0;
if (batch_idx != 0) {
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
const uint i03 = i13 / p.broadcast3;
const uint i02 = i12 / p.broadcast2;
batch_idx_a = i03 * p.ne02 + i02;
}
#else
const uint expert_id = data_ids[expert_idx];
#endif
a_offset =
#ifdef MUL_MAT_ID
expert_id * p.batch_stride_a;
#else
batch_idx_a * p.batch_stride_a;
#endif
b_offset =
#ifdef MUL_MAT_ID
(expert_idx % p.ne11) * p.stride_b;
#else
batch_idx * p.batch_stride_b;
#endif
d_offset =
#ifdef MUL_MAT_ID
expert_idx * p.stride_d;
#else
batch_idx * p.batch_stride_d;
#endif
}
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1;
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = subgroupAdd(temp[j][n]);
}
}
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
}
}
#else
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
// subgroupAdd is probably faster on devices that support it,
// particularly when the workgroup has more than one subgroup
#if USE_SUBGROUP_ADD
// sum up partial sums within a subgroup
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = subgroupAdd(temp[j][n]);
}
}
// Go through shared memory to sum partials across subgroups
if (gl_SubgroupInvocationID == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][gl_SubgroupID] = temp[j][n];
}
}
}
barrier();
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = FLOAT_TYPE(0);
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
temp[j][n] += tmpsh[j][n][s];
}
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
}
#else
// sum up partial sums and write back result
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][tid] = temp[j][n];
}
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
}
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
}
}
}
#endif
}
#endif

View File

@ -0,0 +1,82 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint16_t[4] scales = data_a[ibi].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1));
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1));
const uint qs = data_a[ibi].qs[4 * ib32 + l];
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int k = 0; k < 4; ++k) {
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
}
temp[j][n] = fma(dl, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 8 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 8; // 0...7
const uint ix = tid / 8;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,79 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint qh = data_a[ibi].qh[ib32];
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qs = data_a[ibi].qs[4 * ib32 + l];
const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int k = 0; k < 4; ++k) {
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
}
temp[j][n] = fma(dl, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 8 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 8; // 0...7
const uint ix = tid / 8;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,90 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
const float db = d * (0.5 + scale) * 0.25;
const uint qh = data_a[ibi].qh[ib32];
const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147
const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint8_t sign = sign16[l];
const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);
const uvec2 grid = iq2s_grid[qs];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,87 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
const float db = d * (0.5 + scale) * 0.25;
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs = data_a[ibi].qs[2 * itid + l];
const uint sign = qs >> 9;
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x));
const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,87 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint signscale = pack32(u16vec2(
data_a_packed16[ibi].qs[4 * ib32 + 2],
data_a_packed16[ibi].qs[4 * ib32 + 3]));
const float db = d * 0.25 * (0.5 + (signscale >> 28));
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l];
const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x));
const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,90 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const float dscale = d * (1 + 2 * scale);
const uint qh = data_a[ibi].qh[ib32];
FLOAT_TYPE sum[NUM_COLS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
sum[j] = 0.0;
}
[[unroll]] for (uint l = 0; l < 4; ++l) {
const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147
const uint sign = data_a[ibi].signs[4 * ib32 + l];
const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));
const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
sum[j] =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),
sum[j]))))))));
}
}
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
temp[j][n] = fma(dscale, sum[j], temp[j][n]);
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 8 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 8; // 0...7
const uint ix = tid / 8;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,88 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint signscale = pack32(u16vec2(
data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32],
data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1]));
const float db = d * 0.5 * (0.5 + (signscale >> 28));
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs0 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l];
const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1];
const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0]));
const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1]));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,122 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#define BLOCK_SIZE 32
#define FLOAT_TYPE float
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (push_constant) uniform parameter
{
uint ncols_x;
uint nrows_x;
uint row_stride_x;
uint channel_stride_x;
uint channel_stride_y;
uint channel_x_divisor;
uint ne12;
uint b_offset;
uint d_offset;
uint nb03;
uint nb13;
uint nb23;
} p;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint i3 = gl_WorkGroupID.x;
const uint channel_x = channel / p.channel_x_divisor;
const uint channel_y = channel % p.ne12;
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;
FLOAT_TYPE temp = 0.0f;
// Detect alignment for vector loads
bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
// Unroll 2x and do vec4 loads if aligned
const uint unroll_count = 2;
if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
[[unroll]] for (uint i = 0; i < unroll_count; ++i) {
const uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
temp += dot(av4, bv4);
col_x0 += 4*BLOCK_SIZE;
}
// do vec4 loads if aligned
} else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
const uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
temp += dot(av4, bv4);
col_x0 += 4*BLOCK_SIZE;
} else {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
const uint row_y = col_x;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
col_x0 += BLOCK_SIZE;
}
}
tmp[tid] = temp;
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[idst] = tmp[0];
}
}

View File

@ -0,0 +1,154 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
#define FLOAT_TYPE float
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout(constant_id = 0) const int BLOCK_SIZE = 32;
// gqa_ratio is in the range [1,8]
layout(constant_id = 1) const uint gqa_ratio = 1;
layout (push_constant) uniform parameter
{
uint ncols_x;
uint nrows_x;
uint nchannels_x;
uint nchannels_y;
uint b_offset;
uint d_offset;
} p;
#if !USE_SUBGROUP_ADD
shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
#endif
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
uint channel, channel_x;
// When gqa_ratio > 1, each invocation does multiple rows.
// The row in the A matrix is starting from channel / gqa_ratio and the
// rows in the B matrix are [channel, channel+gqa_ratio).
// When gpa_ratio is 1, each invocation does one row.
if (gqa_ratio > 1) {
channel_x = gl_GlobalInvocationID.z;
channel = channel_x * gqa_ratio;
} else {
channel = gl_GlobalInvocationID.z;
channel_x = channel / (p.nchannels_y / p.nchannels_x);;
}
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
FLOAT_TYPE temp[8];
[[unroll]] for (uint i = 0; i < 8; ++i) {
temp[i] = FLOAT_TYPE(0.0f);
}
// Detect alignment for vector loads
bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
// Use vec4 loads if aligned
if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// y is not transposed but permuted
const uint iy = (channel + c)*nrows_y + row_y;
vec4 bv4 = data_b_v4[iy / 4];
temp[c] += dot(av4, bv4);
}
col_x0 += 3*BLOCK_SIZE;
} else {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
const uint row_y = col_x;
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// y is not transposed but permuted
const uint iy = (channel + c)*nrows_y + row_y;
temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
}
}
}
#if USE_SUBGROUP_ADD
// reduce vec4 at a time
vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
t = subgroupAdd(t);
temp[0] = t[0];
temp[1] = t[1];
temp[2] = t[2];
temp[3] = t[3];
if (gqa_ratio > 4) {
t = vec4(temp[4], temp[5], temp[6], temp[7]);
t = subgroupAdd(t);
temp[4] = t[0];
temp[5] = t[1];
temp[6] = t[2];
temp[7] = t[3];
}
#else
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
tmp[c][tid] = temp[c];
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
temp[c] += tmp[c][tid + s];
tmp[c][tid] = temp[c];
}
}
barrier();
}
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
temp[c] = tmp[c][tid];
}
#endif
if (tid == 0) {
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// dst is not transposed and not permuted
const uint idst = (channel + c)*nrows_dst + row_dst;
dst[idst] = temp[c];
}
}
}

View File

@ -0,0 +1,130 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
uint csel = 0;
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row) {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
}
barrier();
if (i >= num_blocks_per_row)
continue;
} else {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
barrier();
}
const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ],
fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ],
fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ],
fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ],
fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im],
fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im],
fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im],
fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im],
fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im],
fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im],
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - 8*v_im; // 0...7
const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
const uint nbr_par_th = num_blocks_per_row%it_size;
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

Some files were not shown because too many files have changed in this diff Show More